mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix to distribution.__repr__ with lazy attributes (#11263)
Summary: `__repr__` currently fails for distributions with lazy attributes in PyTorch master, throwing a `KeyError`. This fixes the issue. **Additionally:** - Added `logits` to `arg_constraints` for distributions that accept either `probs` or `logits`. This is both to have `__repr__` display the `logits` param when available, and to be able to do validation checks (e.g. NaN checks) when the logit parametrization is used. fritzo, alicanb - I think there were reasons why we had not done so in the first place, but I am unable to recall now. It passes all the tests, but let me know if there is something that I am missing at the moment. - There are certain distributions, e.g. `OneHotCategorical` which won't show any parameters because it uses a `categorical` instance under the hood and neither `logits` / `probs` in `arg_constraints` are present in the instance's `__dict__`. This isn't addressed in this PR. cc. vishwakftw, fritzo, nadavbh12, apaszke Pull Request resolved: https://github.com/pytorch/pytorch/pull/11263 Differential Revision: D9654959 Pulled By: apaszke fbshipit-source-id: 16f5b20243fe8e2c13e9c528050d4df0b8ea6e45
This commit is contained in:
parent
9fc22cb772
commit
434e943b08
|
|
@ -93,6 +93,7 @@ EXAMPLES = [
|
|||
{'probs': torch.tensor([0.7, 0.2, 0.4], requires_grad=True)},
|
||||
{'probs': torch.tensor([0.3], requires_grad=True)},
|
||||
{'probs': 0.3},
|
||||
{'logits': torch.tensor([0.], requires_grad=True)},
|
||||
]),
|
||||
Example(Geometric, [
|
||||
{'probs': torch.tensor([0.7, 0.2, 0.4], requires_grad=True)},
|
||||
|
|
@ -112,6 +113,7 @@ EXAMPLES = [
|
|||
Example(Categorical, [
|
||||
{'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True)},
|
||||
{'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)},
|
||||
{'logits': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
|
||||
]),
|
||||
Example(Binomial, [
|
||||
{'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True), 'total_count': 10},
|
||||
|
|
@ -322,6 +324,7 @@ EXAMPLES = [
|
|||
Example(OneHotCategorical, [
|
||||
{'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True)},
|
||||
{'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)},
|
||||
{'logits': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
|
||||
]),
|
||||
Example(Pareto, [
|
||||
{
|
||||
|
|
@ -713,6 +716,12 @@ class TestDistributions(TestCase):
|
|||
actual = dist(param).enumerate_support()
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
def test_repr(self):
|
||||
for Dist, params in EXAMPLES:
|
||||
for param in params:
|
||||
dist = Dist(**param)
|
||||
self.assertTrue(repr(dist).startswith(dist.__class__.__name__))
|
||||
|
||||
def test_sample_detached(self):
|
||||
for Dist, params in EXAMPLES:
|
||||
for i, param in enumerate(params):
|
||||
|
|
|
|||
|
|
@ -24,7 +24,8 @@ class Bernoulli(ExponentialFamily):
|
|||
probs (Number, Tensor): the probabilty of sampling `1`
|
||||
logits (Number, Tensor): the log-odds of sampling `1`
|
||||
"""
|
||||
arg_constraints = {'probs': constraints.unit_interval}
|
||||
arg_constraints = {'probs': constraints.unit_interval,
|
||||
'logits': constraints.real}
|
||||
support = constraints.boolean
|
||||
has_enumerate_support = True
|
||||
_mean_carrier_measure = 0
|
||||
|
|
|
|||
|
|
@ -28,7 +28,8 @@ class Binomial(Distribution):
|
|||
logits (Tensor): Event log-odds
|
||||
"""
|
||||
arg_constraints = {'total_count': constraints.nonnegative_integer,
|
||||
'probs': constraints.unit_interval}
|
||||
'probs': constraints.unit_interval,
|
||||
'logits': constraints.real}
|
||||
has_enumerate_support = True
|
||||
|
||||
def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
|
||||
|
|
|
|||
|
|
@ -37,7 +37,8 @@ class Categorical(Distribution):
|
|||
probs (Tensor): event probabilities
|
||||
logits (Tensor): event log probabilities
|
||||
"""
|
||||
arg_constraints = {'probs': constraints.simplex}
|
||||
arg_constraints = {'probs': constraints.simplex,
|
||||
'logits': constraints.real}
|
||||
has_enumerate_support = True
|
||||
|
||||
def __init__(self, probs=None, logits=None, validate_args=None):
|
||||
|
|
|
|||
|
|
@ -221,7 +221,7 @@ class Distribution(object):
|
|||
raise ValueError('The value argument must be within the support')
|
||||
|
||||
def __repr__(self):
|
||||
param_names = [k for k, _ in self.arg_constraints.items()]
|
||||
param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__]
|
||||
args_string = ', '.join(['{}: {}'.format(p, self.__dict__[p]
|
||||
if self.__dict__[p].dim() == 0
|
||||
else self.__dict__[p].size()) for p in param_names])
|
||||
|
|
|
|||
|
|
@ -25,7 +25,8 @@ class Geometric(Distribution):
|
|||
probs (Number, Tensor): the probabilty of sampling `1`. Must be in range (0, 1]
|
||||
logits (Number, Tensor): the log-odds of sampling `1`.
|
||||
"""
|
||||
arg_constraints = {'probs': constraints.unit_interval}
|
||||
arg_constraints = {'probs': constraints.unit_interval,
|
||||
'logits': constraints.real}
|
||||
support = constraints.nonnegative_integer
|
||||
|
||||
def __init__(self, probs=None, logits=None, validate_args=None):
|
||||
|
|
|
|||
|
|
@ -38,7 +38,8 @@ class Multinomial(Distribution):
|
|||
probs (Tensor): event probabilities
|
||||
logits (Tensor): event log probabilities
|
||||
"""
|
||||
arg_constraints = {'logits': constraints.real} # Let logits be the canonical parameterization.
|
||||
arg_constraints = {'probs': constraints.simplex,
|
||||
'logits': constraints.real}
|
||||
|
||||
@property
|
||||
def mean(self):
|
||||
|
|
|
|||
|
|
@ -20,7 +20,8 @@ class NegativeBinomial(Distribution):
|
|||
logits (Tensor): Event log-odds for probabilities of success
|
||||
"""
|
||||
arg_constraints = {'total_count': constraints.greater_than_eq(0),
|
||||
'probs': constraints.half_open_interval(0., 1.)}
|
||||
'probs': constraints.half_open_interval(0., 1.),
|
||||
'logits': constraints.real}
|
||||
support = constraints.nonnegative_integer
|
||||
|
||||
def __init__(self, total_count, probs=None, logits=None, validate_args=None):
|
||||
|
|
|
|||
|
|
@ -27,7 +27,8 @@ class OneHotCategorical(Distribution):
|
|||
probs (Tensor): event probabilities
|
||||
logits (Tensor): event log probabilities
|
||||
"""
|
||||
arg_constraints = {'probs': constraints.simplex}
|
||||
arg_constraints = {'probs': constraints.simplex,
|
||||
'logits': constraints.real}
|
||||
support = constraints.simplex
|
||||
has_enumerate_support = True
|
||||
|
||||
|
|
|
|||
|
|
@ -25,7 +25,8 @@ class LogitRelaxedBernoulli(Distribution):
|
|||
[2] Categorical Reparametrization with Gumbel-Softmax
|
||||
(Jang et al, 2017)
|
||||
"""
|
||||
arg_constraints = {'probs': constraints.unit_interval}
|
||||
arg_constraints = {'probs': constraints.unit_interval,
|
||||
'logits': constraints.real}
|
||||
support = constraints.real
|
||||
|
||||
def __init__(self, temperature, probs=None, logits=None, validate_args=None):
|
||||
|
|
@ -92,7 +93,8 @@ class RelaxedBernoulli(TransformedDistribution):
|
|||
probs (Number, Tensor): the probabilty of sampling `1`
|
||||
logits (Number, Tensor): the log-odds of sampling `1`
|
||||
"""
|
||||
arg_constraints = {'probs': constraints.unit_interval}
|
||||
arg_constraints = {'probs': constraints.unit_interval,
|
||||
'logits': constraints.real}
|
||||
support = constraints.unit_interval
|
||||
has_rsample = True
|
||||
|
||||
|
|
|
|||
|
|
@ -29,7 +29,8 @@ class ExpRelaxedCategorical(Distribution):
|
|||
[2] Categorical Reparametrization with Gumbel-Softmax
|
||||
(Jang et al, 2017)
|
||||
"""
|
||||
arg_constraints = {'probs': constraints.simplex}
|
||||
arg_constraints = {'probs': constraints.simplex,
|
||||
'logits': constraints.real}
|
||||
support = constraints.real
|
||||
has_rsample = True
|
||||
|
||||
|
|
@ -93,7 +94,8 @@ class RelaxedOneHotCategorical(TransformedDistribution):
|
|||
probs (Tensor): event probabilities
|
||||
logits (Tensor): the log probability of each event.
|
||||
"""
|
||||
arg_constraints = {'probs': constraints.simplex}
|
||||
arg_constraints = {'probs': constraints.simplex,
|
||||
'logits': constraints.real}
|
||||
support = constraints.simplex
|
||||
has_rsample = True
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user