mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix docstring to clarify logits usage for multiclass case (#51053)
Summary: Fixes https://github.com/pytorch/pytorch/issues/50378. Additionally, this has some minor fixes: - [x] Fix mean for half-cauchy to return `inf` instead of `nan`. - [x] Fix constraints/support for the relaxed categorical distribution. Pull Request resolved: https://github.com/pytorch/pytorch/pull/51053 Reviewed By: heitorschueroff Differential Revision: D26077966 Pulled By: neerajprad fbshipit-source-id: ca0213baa9bbdbc661aebbb901ab5e7fded38a5f
This commit is contained in:
parent
221d7d99e1
commit
e2041ce354
|
|
@ -1511,7 +1511,7 @@ class TestDistributions(TestCase):
|
|||
def test_halfcauchy(self):
|
||||
scale = torch.ones(5, 5, requires_grad=True)
|
||||
scale_1d = torch.ones(1, requires_grad=True)
|
||||
self.assertTrue(is_all_nan(HalfCauchy(scale_1d).mean))
|
||||
self.assertTrue(torch.isinf(HalfCauchy(scale_1d).mean).all())
|
||||
self.assertEqual(HalfCauchy(scale_1d).variance, inf)
|
||||
self.assertEqual(HalfCauchy(scale).sample().size(), (5, 5))
|
||||
self.assertEqual(HalfCauchy(scale).sample((7,)).size(), (7, 5, 5))
|
||||
|
|
|
|||
|
|
@ -16,14 +16,19 @@ class Categorical(Distribution):
|
|||
|
||||
Samples are integers from :math:`\{0, \ldots, K-1\}` where `K` is ``probs.size(-1)``.
|
||||
|
||||
If :attr:`probs` is 1-dimensional with length-`K`, each element is the relative
|
||||
probability of sampling the class at that index.
|
||||
If `probs` is 1-dimensional with length-`K`, each element is the relative probability
|
||||
of sampling the class at that index.
|
||||
|
||||
If :attr:`probs` is N-dimensional, the first N-1 dimensions are treated as a batch of
|
||||
If `probs` is N-dimensional, the first N-1 dimensions are treated as a batch of
|
||||
relative probability vectors.
|
||||
|
||||
.. note:: :attr:`probs` must be non-negative, finite and have a non-zero sum,
|
||||
and it will be normalized to sum to 1 along the last dimension.
|
||||
.. note:: The `probs` argument must be non-negative, finite and have a non-zero sum,
|
||||
and it will be normalized to sum to 1 along the last dimension. attr:`probs`
|
||||
will return this normalized value.
|
||||
The `logits` argument will be interpreted as unnormalized log probabilities
|
||||
and can therefore be any real number. It will likewise be normalized so that
|
||||
the resulting probabilities sum to 1 along the last dimension. attr:`logits`
|
||||
will return this normalized value.
|
||||
|
||||
See also: :func:`torch.multinomial`
|
||||
|
||||
|
|
@ -35,7 +40,7 @@ class Categorical(Distribution):
|
|||
|
||||
Args:
|
||||
probs (Tensor): event probabilities
|
||||
logits (Tensor): event log-odds
|
||||
logits (Tensor): event log probabilities (unnormalized)
|
||||
"""
|
||||
arg_constraints = {'probs': constraints.simplex,
|
||||
'logits': constraints.real_vector}
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ class HalfCauchy(TransformedDistribution):
|
|||
|
||||
@property
|
||||
def mean(self):
|
||||
return self.base_dist.mean
|
||||
return torch.full(self._extended_shape(), math.inf, dtype=self.scale.dtype, device=self.scale.device)
|
||||
|
||||
@property
|
||||
def variance(self):
|
||||
|
|
|
|||
|
|
@ -15,8 +15,13 @@ class Multinomial(Distribution):
|
|||
Note that :attr:`total_count` need not be specified if only :meth:`log_prob` is
|
||||
called (see example below)
|
||||
|
||||
.. note:: :attr:`probs` must be non-negative, finite and have a non-zero sum,
|
||||
and it will be normalized to sum to 1.
|
||||
.. note:: The `probs` argument must be non-negative, finite and have a non-zero sum,
|
||||
and it will be normalized to sum to 1 along the last dimension. attr:`probs`
|
||||
will return this normalized value.
|
||||
The `logits` argument will be interpreted as unnormalized log probabilities
|
||||
and can therefore be any real number. It will likewise be normalized so that
|
||||
the resulting probabilities sum to 1 along the last dimension. attr:`logits`
|
||||
will return this normalized value.
|
||||
|
||||
- :meth:`sample` requires a single shared `total_count` for all
|
||||
parameters and samples.
|
||||
|
|
@ -35,7 +40,7 @@ class Multinomial(Distribution):
|
|||
Args:
|
||||
total_count (int): number of trials
|
||||
probs (Tensor): event probabilities
|
||||
logits (Tensor): event log probabilities
|
||||
logits (Tensor): event log probabilities (unnormalized)
|
||||
"""
|
||||
arg_constraints = {'probs': constraints.simplex,
|
||||
'logits': constraints.real_vector}
|
||||
|
|
|
|||
|
|
@ -11,8 +11,13 @@ class OneHotCategorical(Distribution):
|
|||
|
||||
Samples are one-hot coded vectors of size ``probs.size(-1)``.
|
||||
|
||||
.. note:: :attr:`probs` must be non-negative, finite and have a non-zero sum,
|
||||
and it will be normalized to sum to 1.
|
||||
.. note:: The `probs` argument must be non-negative, finite and have a non-zero sum,
|
||||
and it will be normalized to sum to 1 along the last dimension. attr:`probs`
|
||||
will return this normalized value.
|
||||
The `logits` argument will be interpreted as unnormalized log probabilities
|
||||
and can therefore be any real number. It will likewise be normalized so that
|
||||
the resulting probabilities sum to 1 along the last dimension. attr:`logits`
|
||||
will return this normalized value.
|
||||
|
||||
See also: :func:`torch.distributions.Categorical` for specifications of
|
||||
:attr:`probs` and :attr:`logits`.
|
||||
|
|
@ -25,7 +30,7 @@ class OneHotCategorical(Distribution):
|
|||
|
||||
Args:
|
||||
probs (Tensor): event probabilities
|
||||
logits (Tensor): event log probabilities
|
||||
logits (Tensor): event log probabilities (unnormalized)
|
||||
"""
|
||||
arg_constraints = {'probs': constraints.simplex,
|
||||
'logits': constraints.real_vector}
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ class ExpRelaxedCategorical(Distribution):
|
|||
Args:
|
||||
temperature (Tensor): relaxation temperature
|
||||
probs (Tensor): event probabilities
|
||||
logits (Tensor): the log probability of each event.
|
||||
logits (Tensor): unnormalized log probability for each event
|
||||
|
||||
[1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables
|
||||
(Maddison et al, 2017)
|
||||
|
|
@ -101,7 +101,7 @@ class RelaxedOneHotCategorical(TransformedDistribution):
|
|||
Args:
|
||||
temperature (Tensor): relaxation temperature
|
||||
probs (Tensor): event probabilities
|
||||
logits (Tensor): the log probability of each event.
|
||||
logits (Tensor): unnormalized log probability for each event
|
||||
"""
|
||||
arg_constraints = {'probs': constraints.simplex,
|
||||
'logits': constraints.real_vector}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user