Fix docstring to clarify logits usage for multiclass case (#51053)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/50378.

Additionally, this has some minor fixes:
 - [x] Fix mean for half-cauchy to return `inf` instead of `nan`.
 - [x] Fix constraints/support for the relaxed categorical distribution.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/51053

Reviewed By: heitorschueroff

Differential Revision: D26077966

Pulled By: neerajprad

fbshipit-source-id: ca0213baa9bbdbc661aebbb901ab5e7fded38a5f
This commit is contained in:
neerajprad 2021-01-26 16:59:24 -08:00 committed by Facebook GitHub Bot
parent 221d7d99e1
commit e2041ce354
6 changed files with 31 additions and 16 deletions

View File

@ -1511,7 +1511,7 @@ class TestDistributions(TestCase):
def test_halfcauchy(self):
scale = torch.ones(5, 5, requires_grad=True)
scale_1d = torch.ones(1, requires_grad=True)
self.assertTrue(is_all_nan(HalfCauchy(scale_1d).mean))
self.assertTrue(torch.isinf(HalfCauchy(scale_1d).mean).all())
self.assertEqual(HalfCauchy(scale_1d).variance, inf)
self.assertEqual(HalfCauchy(scale).sample().size(), (5, 5))
self.assertEqual(HalfCauchy(scale).sample((7,)).size(), (7, 5, 5))

View File

@ -16,14 +16,19 @@ class Categorical(Distribution):
Samples are integers from :math:`\{0, \ldots, K-1\}` where `K` is ``probs.size(-1)``.
If :attr:`probs` is 1-dimensional with length-`K`, each element is the relative
probability of sampling the class at that index.
If `probs` is 1-dimensional with length-`K`, each element is the relative probability
of sampling the class at that index.
If :attr:`probs` is N-dimensional, the first N-1 dimensions are treated as a batch of
If `probs` is N-dimensional, the first N-1 dimensions are treated as a batch of
relative probability vectors.
.. note:: :attr:`probs` must be non-negative, finite and have a non-zero sum,
and it will be normalized to sum to 1 along the last dimension.
.. note:: The `probs` argument must be non-negative, finite and have a non-zero sum,
and it will be normalized to sum to 1 along the last dimension. :attr:`probs`
will return this normalized value.
The `logits` argument will be interpreted as unnormalized log probabilities
and can therefore be any real number. It will likewise be normalized so that
the resulting probabilities sum to 1 along the last dimension. :attr:`logits`
will return this normalized value.
See also: :func:`torch.multinomial`
@ -35,7 +40,7 @@ class Categorical(Distribution):
Args:
probs (Tensor): event probabilities
logits (Tensor): event log-odds
logits (Tensor): event log probabilities (unnormalized)
"""
arg_constraints = {'probs': constraints.simplex,
'logits': constraints.real_vector}

View File

@ -43,7 +43,7 @@ class HalfCauchy(TransformedDistribution):
@property
def mean(self):
return self.base_dist.mean
return torch.full(self._extended_shape(), math.inf, dtype=self.scale.dtype, device=self.scale.device)
@property
def variance(self):

View File

@ -15,8 +15,13 @@ class Multinomial(Distribution):
Note that :attr:`total_count` need not be specified if only :meth:`log_prob` is
called (see example below)
.. note:: :attr:`probs` must be non-negative, finite and have a non-zero sum,
and it will be normalized to sum to 1.
.. note:: The `probs` argument must be non-negative, finite and have a non-zero sum,
and it will be normalized to sum to 1 along the last dimension. :attr:`probs`
will return this normalized value.
The `logits` argument will be interpreted as unnormalized log probabilities
and can therefore be any real number. It will likewise be normalized so that
the resulting probabilities sum to 1 along the last dimension. :attr:`logits`
will return this normalized value.
- :meth:`sample` requires a single shared `total_count` for all
parameters and samples.
@ -35,7 +40,7 @@ class Multinomial(Distribution):
Args:
total_count (int): number of trials
probs (Tensor): event probabilities
logits (Tensor): event log probabilities
logits (Tensor): event log probabilities (unnormalized)
"""
arg_constraints = {'probs': constraints.simplex,
'logits': constraints.real_vector}

View File

@ -11,8 +11,13 @@ class OneHotCategorical(Distribution):
Samples are one-hot coded vectors of size ``probs.size(-1)``.
.. note:: :attr:`probs` must be non-negative, finite and have a non-zero sum,
and it will be normalized to sum to 1.
.. note:: The `probs` argument must be non-negative, finite and have a non-zero sum,
and it will be normalized to sum to 1 along the last dimension. :attr:`probs`
will return this normalized value.
The `logits` argument will be interpreted as unnormalized log probabilities
and can therefore be any real number. It will likewise be normalized so that
the resulting probabilities sum to 1 along the last dimension. :attr:`logits`
will return this normalized value.
See also: :func:`torch.distributions.Categorical` for specifications of
:attr:`probs` and :attr:`logits`.
@ -25,7 +30,7 @@ class OneHotCategorical(Distribution):
Args:
probs (Tensor): event probabilities
logits (Tensor): event log probabilities
logits (Tensor): event log probabilities (unnormalized)
"""
arg_constraints = {'probs': constraints.simplex,
'logits': constraints.real_vector}

View File

@ -21,7 +21,7 @@ class ExpRelaxedCategorical(Distribution):
Args:
temperature (Tensor): relaxation temperature
probs (Tensor): event probabilities
logits (Tensor): the log probability of each event.
logits (Tensor): unnormalized log probability for each event
[1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables
(Maddison et al, 2017)
@ -101,7 +101,7 @@ class RelaxedOneHotCategorical(TransformedDistribution):
Args:
temperature (Tensor): relaxation temperature
probs (Tensor): event probabilities
logits (Tensor): the log probability of each event.
logits (Tensor): unnormalized log probability for each event
"""
arg_constraints = {'probs': constraints.simplex,
'logits': constraints.real_vector}