mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-08 07:39:33 +01:00
This PR fixes #69466 and introduces some other minor changes. Tests are somewhat more involved because a reference implementation in `scipy` is not available; tests proceed differently for discrete and continuous distributions. For continuous distributions, we evaluate the gradient of the `log_prob` at the mode. Tests pass if the gradient is zero OR (the mode is at the boundary of the support of the distribution AND the `log_prob` decreases as we move away from the boundary to the interior of the support). For discrete distributions, the notion of a gradient is not well defined. We thus "look" ahead and behind one step (e.g. if the mode of a Poisson distribution is 9, we consider 8 and 10). If the step ahead/behind is still within the support of the distribution, we assert that the `log_prob` is smaller than at the mode. For one-hot encoded distributions (currently just `OneHotCategorical`), we evaluate the underlying mode (i.e. encoded as an integer tensor), "advance" by one label to get another sample that should have lower probability using `other = (mode + 1) % event_size`, and re-encode it as one-hot. The resulting `other` sample should have lower probability than the mode. Furthermore, the Gamma, half-Cauchy, and half-normal distributions have their support changed from positive to nonnegative. This change is necessary because the mode of the "half" distributions is zero, and the mode of the gamma distribution is zero for `concentration <= 1`.
cc @fritzo
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76690
Approved by: https://github.com/neerajprad
75 lines
2.1 KiB
Python
75 lines
2.1 KiB
Python
from numbers import Number
|
|
|
|
import torch
|
|
from torch.distributions import constraints
|
|
from torch.distributions.exp_family import ExponentialFamily
|
|
from torch.distributions.utils import broadcast_all
|
|
|
|
|
|
class Poisson(ExponentialFamily):
    r"""
    Creates a Poisson distribution parameterized by :attr:`rate`, the rate parameter.

    Samples are nonnegative integers, with a pmf given by

    .. math::
      \mathrm{rate}^k \frac{e^{-\mathrm{rate}}}{k!}

    Example::

        >>> m = Poisson(torch.tensor([4.0]))
        >>> m.sample()  # non-deterministic
        tensor([3.])

    Args:
        rate (Number, Tensor): the rate parameter; must be nonnegative.
            Tensor arguments keep their dtype (only Python numbers are
            converted to the default floating dtype), so pass a
            floating-point tensor: :func:`torch.poisson`, used by
            :meth:`sample`, expects floating-point rates.
    """
    arg_constraints = {'rate': constraints.nonnegative}
    support = constraints.nonnegative_integer

    @property
    def mean(self):
        """Mean of the distribution, which equals ``rate``."""
        return self.rate

    @property
    def mode(self):
        """Mode of the distribution, ``floor(rate)``.

        When ``rate`` is a positive integer, ``rate - 1`` is an equally
        probable mode; this property reports the larger of the two.
        """
        return self.rate.floor()

    @property
    def variance(self):
        """Variance of the distribution, which equals ``rate``."""
        return self.rate

    def __init__(self, rate, validate_args=None):
        """Store ``rate`` as a tensor and derive the batch shape from it.

        Args:
            rate (Number, Tensor): the rate parameter.
            validate_args (bool, optional): forwarded to the base
                ``Distribution`` to enable/disable argument validation.
        """
        self.rate, = broadcast_all(rate)
        # A Python number describes a single (scalar) distribution; a tensor
        # describes a batch with the tensor's own shape.
        if isinstance(rate, Number):
            batch_shape = torch.Size()
        else:
            batch_shape = self.rate.size()
        super().__init__(batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a new ``Poisson`` whose ``rate`` is expanded to ``batch_shape``.

        The expanded ``rate`` is a view (``Tensor.expand``) of the original.
        Validation is skipped while re-initialising, and the original
        instance's validation flag is then copied over.
        """
        new = self._get_checked_instance(Poisson, _instance)
        batch_shape = torch.Size(batch_shape)
        new.rate = self.rate.expand(batch_shape)
        super(Poisson, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def sample(self, sample_shape=torch.Size()):
        """Draw samples of shape ``sample_shape + batch_shape``.

        Sampling is not reparameterized, so it runs under ``torch.no_grad()``
        and the result carries no gradient history.
        """
        shape = self._extended_shape(sample_shape)
        with torch.no_grad():
            return torch.poisson(self.rate.expand(shape))

    def log_prob(self, value):
        """Log of the pmf: ``k * log(rate) - rate - log(k!)`` evaluated at ``value``.

        ``xlogy`` yields 0 for ``value == 0`` even when ``rate == 0``, so the
        degenerate distribution at zero gets ``log_prob(0) == 0`` rather than
        NaN from ``0 * log(0)``. ``log(k!)`` is computed as ``lgamma(k + 1)``.
        """
        if self._validate_args:
            self._validate_sample(value)
        rate, value = broadcast_all(self.rate, value)
        return value.xlogy(rate) - rate - (value + 1).lgamma()

    @property
    def _natural_params(self):
        """Natural parameter of the exponential-family form: ``(log(rate),)``."""
        return (torch.log(self.rate), )

    def _log_normalizer(self, x):
        """Log-normalizer ``A(eta) = exp(eta)`` for natural parameter ``x``."""
        return torch.exp(x)
|