Fix to distribution.__repr__ with lazy attributes (#11263)

Summary:
`__repr__` currently fails on PyTorch master for distributions with lazy attributes, throwing a `KeyError`: it indexes every parameter in `arg_constraints` out of the instance's `__dict__`, but a lazily computed parameter (e.g. `probs` when the distribution was constructed from `logits`) is not present there until first accessed. This fixes the issue.
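
A minimal repro (a sketch; exact output depends on the build):

```python
import torch
from torch.distributions import Bernoulli

# Built from `logits`, so `probs` is a lazy attribute and is absent from the
# instance's __dict__ until first accessed -- the old __repr__ indexed it anyway.
d = Bernoulli(logits=torch.zeros(3))
print(d)  # before this fix: KeyError; after: Bernoulli(logits: torch.Size([3]))
```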

**Additionally:**
 - Added `logits` to `arg_constraints` for distributions that accept either `probs` or `logits`. This both lets `__repr__` display the `logits` param when available and enables validation checks (e.g. NaN checks) when the logit parametrization is used; see the sketch after this list. fritzo, alicanb - I think there were reasons we did not do this in the first place, but I cannot recall them now. All tests pass, but let me know if I am missing something.
 - Certain distributions, e.g. `OneHotCategorical`, will still show no parameters, because they use a `Categorical` instance under the hood and neither `logits` nor `probs` from `arg_constraints` appears in the instance's `__dict__`. This isn't addressed in this PR.
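
As a quick illustration of the validation point above (a minimal sketch; the exact error message may differ across versions):

```python
import torch
from torch.distributions import Bernoulli

# With 'logits' now in arg_constraints, a NaN logit fails the constraints.real
# check when argument validation is on, instead of slipping through silently.
try:
    Bernoulli(logits=torch.tensor([float('nan')]), validate_args=True)
except ValueError as err:
    print(err)
```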

cc. vishwakftw, fritzo, nadavbh12, apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11263

Differential Revision: D9654959

Pulled By: apaszke

fbshipit-source-id: 16f5b20243fe8e2c13e9c528050d4df0b8ea6e45
Author: Neeraj Pradhan (2018-09-05 09:44:54 -07:00), committed by Facebook GitHub Bot
parent 9fc22cb772
commit 434e943b08
11 changed files with 32 additions and 12 deletions

test/test_distributions.py

@@ -93,6 +93,7 @@ EXAMPLES = [
         {'probs': torch.tensor([0.7, 0.2, 0.4], requires_grad=True)},
         {'probs': torch.tensor([0.3], requires_grad=True)},
         {'probs': 0.3},
+        {'logits': torch.tensor([0.], requires_grad=True)},
     ]),
     Example(Geometric, [
         {'probs': torch.tensor([0.7, 0.2, 0.4], requires_grad=True)},
@@ -112,6 +113,7 @@ EXAMPLES = [
     Example(Categorical, [
         {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True)},
         {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)},
+        {'logits': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
     ]),
     Example(Binomial, [
         {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True), 'total_count': 10},
@@ -322,6 +324,7 @@ EXAMPLES = [
     Example(OneHotCategorical, [
         {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True)},
         {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)},
+        {'logits': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
     ]),
     Example(Pareto, [
         {
@@ -713,6 +716,12 @@ class TestDistributions(TestCase):
             actual = dist(param).enumerate_support()
             self.assertEqual(actual, expected)
 
+    def test_repr(self):
+        for Dist, params in EXAMPLES:
+            for param in params:
+                dist = Dist(**param)
+                self.assertTrue(repr(dist).startswith(dist.__class__.__name__))
+
     def test_sample_detached(self):
         for Dist, params in EXAMPLES:
             for i, param in enumerate(params):

torch/distributions/bernoulli.py

@@ -24,7 +24,8 @@ class Bernoulli(ExponentialFamily):
         probs (Number, Tensor): the probabilty of sampling `1`
         logits (Number, Tensor): the log-odds of sampling `1`
     """
-    arg_constraints = {'probs': constraints.unit_interval}
+    arg_constraints = {'probs': constraints.unit_interval,
+                       'logits': constraints.real}
     support = constraints.boolean
     has_enumerate_support = True
     _mean_carrier_measure = 0

torch/distributions/binomial.py

@@ -28,7 +28,8 @@ class Binomial(Distribution):
         logits (Tensor): Event log-odds
     """
     arg_constraints = {'total_count': constraints.nonnegative_integer,
-                       'probs': constraints.unit_interval}
+                       'probs': constraints.unit_interval,
+                       'logits': constraints.real}
     has_enumerate_support = True
 
     def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):

torch/distributions/categorical.py

@@ -37,7 +37,8 @@ class Categorical(Distribution):
         probs (Tensor): event probabilities
         logits (Tensor): event log probabilities
     """
-    arg_constraints = {'probs': constraints.simplex}
+    arg_constraints = {'probs': constraints.simplex,
+                       'logits': constraints.real}
     has_enumerate_support = True
 
     def __init__(self, probs=None, logits=None, validate_args=None):

torch/distributions/distribution.py

@@ -221,7 +221,7 @@ class Distribution(object):
             raise ValueError('The value argument must be within the support')
 
     def __repr__(self):
-        param_names = [k for k, _ in self.arg_constraints.items()]
+        param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__]
         args_string = ', '.join(['{}: {}'.format(p, self.__dict__[p]
                                 if self.__dict__[p].dim() == 0
                                 else self.__dict__[p].size()) for p in param_names])
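
For reference, a sketch of the resulting behavior (output format approximate): parameters missing from `__dict__` (i.e. lazily computed ones) are simply skipped.

```python
import torch
from torch.distributions import Bernoulli, Normal

# 0-dim parameters print their values; higher-dim parameters print their sizes.
print(Normal(0., 1.))                    # e.g. Normal(loc: tensor(0.), scale: tensor(1.))
print(Bernoulli(logits=torch.zeros(2)))  # e.g. Bernoulli(logits: torch.Size([2]))
```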

torch/distributions/geometric.py

@@ -25,7 +25,8 @@ class Geometric(Distribution):
         probs (Number, Tensor): the probabilty of sampling `1`. Must be in range (0, 1]
         logits (Number, Tensor): the log-odds of sampling `1`.
     """
-    arg_constraints = {'probs': constraints.unit_interval}
+    arg_constraints = {'probs': constraints.unit_interval,
+                       'logits': constraints.real}
     support = constraints.nonnegative_integer
 
    def __init__(self, probs=None, logits=None, validate_args=None):

torch/distributions/multinomial.py

@@ -38,7 +38,8 @@ class Multinomial(Distribution):
         probs (Tensor): event probabilities
         logits (Tensor): event log probabilities
     """
-    arg_constraints = {'logits': constraints.real}  # Let logits be the canonical parameterization.
+    arg_constraints = {'probs': constraints.simplex,
+                       'logits': constraints.real}
 
     @property
     def mean(self):

torch/distributions/negative_binomial.py

@@ -20,7 +20,8 @@ class NegativeBinomial(Distribution):
         logits (Tensor): Event log-odds for probabilities of success
     """
     arg_constraints = {'total_count': constraints.greater_than_eq(0),
-                       'probs': constraints.half_open_interval(0., 1.)}
+                       'probs': constraints.half_open_interval(0., 1.),
+                       'logits': constraints.real}
     support = constraints.nonnegative_integer
 
     def __init__(self, total_count, probs=None, logits=None, validate_args=None):

torch/distributions/one_hot_categorical.py

@@ -27,7 +27,8 @@ class OneHotCategorical(Distribution):
         probs (Tensor): event probabilities
         logits (Tensor): event log probabilities
     """
-    arg_constraints = {'probs': constraints.simplex}
+    arg_constraints = {'probs': constraints.simplex,
+                       'logits': constraints.real}
     support = constraints.simplex
     has_enumerate_support = True
 
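As noted in the summary, `OneHotCategorical` still prints no parameters even with `logits` listed here, since both live on the inner `Categorical` (a sketch):

```python
import torch
from torch.distributions import OneHotCategorical

d = OneHotCategorical(probs=torch.tensor([0.3, 0.7]))
print(d)  # OneHotCategorical() -- probs/logits live on d._categorical, not d.__dict__
```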

torch/distributions/relaxed_bernoulli.py

@@ -25,7 +25,8 @@ class LogitRelaxedBernoulli(Distribution):
     [2] Categorical Reparametrization with Gumbel-Softmax
     (Jang et al, 2017)
     """
-    arg_constraints = {'probs': constraints.unit_interval}
+    arg_constraints = {'probs': constraints.unit_interval,
+                       'logits': constraints.real}
     support = constraints.real
 
     def __init__(self, temperature, probs=None, logits=None, validate_args=None):
@@ -92,7 +93,8 @@ class RelaxedBernoulli(TransformedDistribution):
         probs (Number, Tensor): the probabilty of sampling `1`
         logits (Number, Tensor): the log-odds of sampling `1`
     """
-    arg_constraints = {'probs': constraints.unit_interval}
+    arg_constraints = {'probs': constraints.unit_interval,
+                       'logits': constraints.real}
     support = constraints.unit_interval
     has_rsample = True
 

torch/distributions/relaxed_categorical.py

@@ -29,7 +29,8 @@ class ExpRelaxedCategorical(Distribution):
     [2] Categorical Reparametrization with Gumbel-Softmax
     (Jang et al, 2017)
     """
-    arg_constraints = {'probs': constraints.simplex}
+    arg_constraints = {'probs': constraints.simplex,
+                       'logits': constraints.real}
     support = constraints.real
     has_rsample = True
 
@@ -93,7 +94,8 @@ class RelaxedOneHotCategorical(TransformedDistribution):
         probs (Tensor): event probabilities
         logits (Tensor): the log probability of each event.
     """
-    arg_constraints = {'probs': constraints.simplex}
+    arg_constraints = {'probs': constraints.simplex,
+                       'logits': constraints.real}
     support = constraints.simplex
     has_rsample = True
 