This is a new version of #15648 based on the latest master branch. Unlike the previous PR, where I fixed many of the doctests in addition to integrating xdoctest, I'm reducing the scope here: I'm simply going to integrate xdoctest and mark all of the failing tests as "SKIP". This lets xdoctest run on the dashboards and provide some value, while still letting the dashboards pass. I'll leave fixing the doctests themselves to another PR.

In my initial commit, I do the bare minimum to get something running, with failing dashboards; the few tests that I marked as skip were causing segfaults. Running xdoctest results in 293 failed and 201 passed tests. The next commits will disable those failing tests. (Unfortunately, I don't have a tool that will insert the `# xdoctest: +SKIP` directive over every failing test, so I'm going to do this mostly manually.)

Fixes https://github.com/pytorch/pytorch/issues/71105

@ezyang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82797
Approved by: https://github.com/ezyang
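For reference, the directive is just a comment placed on a doctest line inside the docstring. A minimal illustration of the pattern (`frobnicate` is a made-up example function, not one of the tests touched by this PR):

```python
import torch

def frobnicate(x):
    """
    >>> # xdoctest: +SKIP("fails on the CI dashboards")
    >>> frobnicate(torch.tensor([1.0]))
    tensor([2.])
    """
    return x + 1
```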
77 lines
2.2 KiB
Python
from numbers import Number

import torch
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.utils import broadcast_all

__all__ = ['Poisson']


class Poisson(ExponentialFamily):
    r"""
    Creates a Poisson distribution parameterized by :attr:`rate`, the rate parameter.

    Samples are nonnegative integers, with a pmf given by

    .. math::
        \mathrm{rate}^k \frac{e^{-\mathrm{rate}}}{k!}

    Example::

        >>> # xdoctest: +SKIP("poisson_cpu not implemented for 'Long'")
        >>> m = Poisson(torch.tensor([4]))
        >>> m.sample()
        tensor([ 3.])

    Args:
        rate (Number, Tensor): the rate parameter
    """
    arg_constraints = {'rate': constraints.nonnegative}
    support = constraints.nonnegative_integer

    @property
    def mean(self):
        # For a Poisson distribution, the mean equals the rate.
        return self.rate

    @property
    def mode(self):
        # The mode of Poisson(rate) is floor(rate).
        return self.rate.floor()

    @property
    def variance(self):
        # The variance also equals the rate (same as the mean).
        return self.rate

    def __init__(self, rate, validate_args=None):
        self.rate, = broadcast_all(rate)
        # A scalar rate gives an empty batch shape; a tensor rate batches
        # over its own shape.
        if isinstance(rate, Number):
            batch_shape = torch.Size()
        else:
            batch_shape = self.rate.size()
        super(Poisson, self).__init__(batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(Poisson, _instance)
        batch_shape = torch.Size(batch_shape)
        new.rate = self.rate.expand(batch_shape)
        super(Poisson, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def sample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        with torch.no_grad():
            # torch.poisson draws one sample per element of the rate tensor.
            return torch.poisson(self.rate.expand(shape))

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        rate, value = broadcast_all(self.rate, value)
        # log pmf: k * log(rate) - rate - log(k!), using lgamma(k + 1) = log(k!);
        # xlogy returns 0 rather than NaN when value == 0 and rate == 0.
        return value.xlogy(rate) - rate - (value + 1).lgamma()

    @property
    def _natural_params(self):
        # Natural parameter of the Poisson as an exponential family:
        # theta = log(rate).
        return (torch.log(self.rate), )

    def _log_normalizer(self, x):
        # Log normalizer A(theta) = exp(theta) = rate.
        return torch.exp(x)
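

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: for this exponential
# family, log_prob(k) should agree with theta * k - A(theta) - log(k!), where
# theta = log(rate) is the natural parameter (see _natural_params) and
# A(theta) = exp(theta) = rate is the log normalizer (see _log_normalizer).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    rate = torch.tensor([4.0])
    k = torch.tensor([3.0])
    lp = Poisson(rate).log_prob(k)
    theta = torch.log(rate)
    manual = theta * k - torch.exp(theta) - torch.lgamma(k + 1)
    # Both routes compute log(rate^k * e^{-rate} / k!) from the docstring pmf.
    assert torch.allclose(lp, manual)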