pytorch/torch/distributions/negative_binomial.py
Neeraj Pradhan 80fa8e1007 Add .expand() method to distribution classes (#11341)
Summary:
This adds a `.expand` method for distributions that is akin to the `torch.Tensor.expand` method for tensors. It returns a new distribution instance with batch dimensions expanded to the desired `batch_shape`. Since this calls `torch.Tensor.expand` on the distribution's parameters, it does not allocate new memory for the expanded distribution instance's parameters.

e.g.
```python
>>> d = dist.Normal(torch.zeros(100, 1), torch.ones(100, 1))
>>> d.sample().shape
  torch.Size([100, 1])
>>> d.expand([100, 10]).sample().shape
  torch.Size([100, 10])
```
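The claim that expanding does not allocate new memory can be checked directly. The sketch below assumes a current PyTorch install. It shows that the expanded parameter is a view of the original tensor: it reports the same `data_ptr()` and a stride of 0 along the broadcast dimension.

```python
import torch
import torch.distributions as dist

d = dist.Normal(torch.zeros(100, 1), torch.ones(100, 1))
e = d.expand([100, 10])

print(e.batch_shape)                         # torch.Size([100, 10])
# The expanded parameters are views: same storage, stride 0 on the new axis.
print(e.loc.data_ptr() == d.loc.data_ptr())  # True
print(e.loc.stride())                        # (1, 0)
```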

We have already been using the `.expand` method in Pyro's [patch](https://github.com/uber/pyro/blob/dev/pyro/distributions/torch.py#L10) of `torch.distributions`, where it enables dynamic broadcasting in our models. A few users on the distributions Slack have also requested it, and we believe it will be useful to the larger community.

Note that there is currently no convenient and efficient way to expand distribution instances:
 - Many distributions use `TransformedDistribution` under the hood, wrap another distribution instance (e.g. `OneHotCategorical` uses a `Categorical` instance), or have lazy parameters. This makes it difficult to collect all the relevant parameters, broadcast them, and construct new instances.
 - In the few cases where this is possible at all, the resulting implementation would be inefficient, since it would go through broadcasting and argument-validation logic in `__init__` that can be avoided.

The `.expand` method provides a safe and efficient way to expand distribution instances. It also bypasses the subclass `__init__` (using `__new__` and populating the relevant attributes directly), because no broadcasting or argument validation is needed: that was already done when the instance was first created. This can give significant savings compared to constructing new instances via `__init__`, although the `sample` and `log_prob` methods will probably remain the rate-determining steps in many applications.

e.g.
```python
>>> a = dist.Bernoulli(torch.ones([10000, 1]), validate_args=True)

>>> %timeit a.expand([10000, 100])
15.2 µs ± 224 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

>>> %timeit dist.Bernoulli(torch.ones([10000, 100]), validate_args=True)
11.8 ms ± 153 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```
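The "bypass `__init__`" idea can be illustrated with a toy class. This is a minimal sketch under stated assumptions: `ToyBernoulli` is hypothetical and is not `torch.distributions.Bernoulli`, and PyTorch's own `expand` goes through `Distribution._get_checked_instance` rather than calling `__new__` directly as done here. The sketch only shows the pattern: allocate the object with `__new__`, then copy over expanded views of already-validated parameters.

```python
import torch


class ToyBernoulli:
    """Hypothetical minimal distribution used only to illustrate the pattern."""

    init_calls = 0  # counts constructor runs, for illustration only

    def __init__(self, probs, validate_args=True):
        type(self).init_calls += 1
        probs = torch.as_tensor(probs, dtype=torch.float)
        # Stand-in for the (potentially expensive) argument validation.
        if validate_args and not bool(((probs >= 0) & (probs <= 1)).all()):
            raise ValueError("probs must lie in [0, 1]")
        self.probs = probs
        self.batch_shape = probs.shape

    def expand(self, batch_shape):
        batch_shape = torch.Size(batch_shape)
        # Create the instance without calling __init__: the parameters were
        # already validated when `self` was constructed.
        new = type(self).__new__(type(self))
        new.probs = self.probs.expand(batch_shape)  # a view; no new memory
        new.batch_shape = batch_shape
        return new


a = ToyBernoulli(torch.ones(10000, 1))
b = a.expand([10000, 100])
print(ToyBernoulli.init_calls)                   # 1: expand() skipped __init__
print(b.batch_shape)                             # torch.Size([10000, 100])
print(b.probs.data_ptr() == a.probs.data_ptr())  # True
```

Skipping validation is safe here only because the expanded tensor is a view of values that already passed the checks.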

cc. fritzo, apaszke, vishwakftw, alicanb
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11341

Differential Revision: D9728485

Pulled By: soumith

fbshipit-source-id: 3b94c23bc6a43ee704389e6287aa83d1e278d52f
2018-09-11 06:56:18 -07:00


```python
import torch
import torch.nn.functional as F
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all, probs_to_logits, lazy_property, logits_to_probs


class NegativeBinomial(Distribution):
    r"""
    Creates a Negative Binomial distribution, i.e. distribution
    of the number of successful independent and identical Bernoulli trials
    before :attr:`total_count` failures are achieved. The probability
    of success of each Bernoulli trial is :attr:`probs`.

    Args:
        total_count (float or Tensor): non-negative number of failed Bernoulli
            trials at which to stop, although the distribution is still valid
            for a real-valued count
        probs (Tensor): Event probabilities of success in the half open interval [0, 1)
        logits (Tensor): Event log-odds for probabilities of success
    """
    arg_constraints = {'total_count': constraints.greater_than_eq(0),
                       'probs': constraints.half_open_interval(0., 1.),
                       'logits': constraints.real}
    support = constraints.nonnegative_integer

    def __init__(self, total_count, probs=None, logits=None, validate_args=None):
        if (probs is None) == (logits is None):
            raise ValueError("Either `probs` or `logits` must be specified, but not both.")
        if probs is not None:
            self.total_count, self.probs, = broadcast_all(total_count, probs)
            self.total_count = self.total_count.type_as(self.probs)
        else:
            self.total_count, self.logits, = broadcast_all(total_count, logits)
            self.total_count = self.total_count.type_as(self.logits)

        self._param = self.probs if probs is not None else self.logits
        batch_shape = self._param.size()
        super(NegativeBinomial, self).__init__(batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(NegativeBinomial, _instance)
        batch_shape = torch.Size(batch_shape)
        new.total_count = self.total_count.expand(batch_shape)
        if 'probs' in self.__dict__:
            new.probs = self.probs.expand(batch_shape)
            new._param = new.probs
        else:
            new.logits = self.logits.expand(batch_shape)
            new._param = new.logits
        super(NegativeBinomial, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def _new(self, *args, **kwargs):
        return self._param.new(*args, **kwargs)

    @property
    def mean(self):
        return self.total_count * torch.exp(self.logits)

    @property
    def variance(self):
        return self.mean / torch.sigmoid(-self.logits)

    @lazy_property
    def logits(self):
        return probs_to_logits(self.probs, is_binary=True)

    @lazy_property
    def probs(self):
        return logits_to_probs(self.logits, is_binary=True)

    @property
    def param_shape(self):
        return self._param.size()

    @lazy_property
    def _gamma(self):
        return torch.distributions.Gamma(concentration=self.total_count,
                                         rate=torch.exp(-self.logits))

    def sample(self, sample_shape=torch.Size()):
        with torch.no_grad():
            rate = self._gamma.sample(sample_shape=sample_shape)
            return torch.poisson(rate)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)

        log_unnormalized_prob = (self.total_count * F.logsigmoid(-self.logits) +
                                 value * F.logsigmoid(self.logits))

        log_normalization = (-torch.lgamma(self.total_count + value) + torch.lgamma(1. + value) +
                             torch.lgamma(self.total_count))

        return log_unnormalized_prob - log_normalization
```
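With $r$ = `total_count` and $p$ = `probs`, `log_prob` computes $\log P(k) = r\log(1-p) + k\log p + \log\Gamma(r+k) - \log\Gamma(k+1) - \log\Gamma(r)$. Here `F.logsigmoid(-logits)` is $\log(1-p)$ and `F.logsigmoid(logits)` is $\log p$. The `mean` and `variance` properties are $rp/(1-p)$ and $rp/(1-p)^2$, and `sample` draws a Gamma-Poisson mixture: $\lambda \sim \mathrm{Gamma}(r, \text{rate}=(1-p)/p)$, then $k \sim \mathrm{Poisson}(\lambda)$. The sketch below checks these facts numerically. It assumes a recent PyTorch install whose `NegativeBinomial` uses the same formulas. The helper `nb_log_pmf` and the parameter values are hypothetical and chosen only for illustration.

```python
import math

import torch
from torch.distributions import NegativeBinomial


def nb_log_pmf(k, r, p):
    # Hypothetical reference helper: the closed-form log-pmf that log_prob implements,
    # log P(k) = r*log(1-p) + k*log(p) + lgamma(r+k) - lgamma(k+1) - lgamma(r).
    return (r * math.log1p(-p) + k * math.log(p)
            + math.lgamma(r + k) - math.lgamma(k + 1) - math.lgamma(r))


r, p = 3.5, 0.4  # arbitrary example values; a real-valued total_count is allowed
d = NegativeBinomial(total_count=torch.tensor(r, dtype=torch.float64),
                     probs=torch.tensor(p, dtype=torch.float64))

# 1) log_prob agrees with the closed form on k = 0..199.
ks = torch.arange(200, dtype=torch.float64)
log_probs = d.log_prob(ks)
ref = torch.tensor([nb_log_pmf(k, r, p) for k in range(200)], dtype=torch.float64)
print(torch.allclose(log_probs, ref))

# 2) The pmf sums to ~1 (the tail beyond k = 199 is negligible for these values).
print(log_probs.exp().sum().item())

# 3) mean = r*p/(1-p) and variance = r*p/(1-p)**2.
print(d.mean.item(), r * p / (1 - p))
print(d.variance.item(), r * p / (1 - p) ** 2)

# 4) sample() draws Poisson(Gamma(r, rate=(1-p)/p)) variates, so the empirical
#    mean should be close to d.mean (Monte Carlo, so only approximately).
torch.manual_seed(0)
samples = d.sample((200_000,))
print(samples.mean().item())

# 5) expand() returns an instance with a larger batch shape whose probs is a view.
e = d.expand([4])
print(e.batch_shape, e.probs.data_ptr() == d.probs.data_ptr())
```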