This updates ruff to 0.285, which is faster, better, and fixes a bunch of false negatives with regard to f-strings. I also enabled RUF017, which looks for accidental quadratic list summation. Luckily, there seem to be no instances of it in our codebase, so I'm enabling the rule to keep it that way. :)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
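For context, RUF017 flags list flattening via sum(), where each addition copies the growing accumulator list, making the whole operation quadratic. A minimal sketch of the pattern and a linear alternative (illustrative example only, not code from this PR):

    import itertools

    lists = [[1, 2], [3, 4], [5, 6]]

    # Flagged by RUF017: every `+` builds a fresh list and copies the
    # accumulator, so flattening n elements costs O(n^2).
    flat = sum(lists, [])

    # Linear alternative: chain the sublists and materialize once.
    flat = list(itertools.chain.from_iterable(lists))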
from typing import Dict

import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import _sum_rightmost

__all__ = ["Independent"]


class Independent(Distribution):
    r"""
    Reinterprets some of the batch dims of a distribution as event dims.

    This is mainly useful for changing the shape of the result of
    :meth:`log_prob`. For example, to create a diagonal Normal distribution with
    the same shape as a Multivariate Normal distribution (so they are
    interchangeable), you can::

        >>> from torch.distributions.multivariate_normal import MultivariateNormal
        >>> from torch.distributions.normal import Normal
        >>> loc = torch.zeros(3)
        >>> scale = torch.ones(3)
        >>> mvn = MultivariateNormal(loc, scale_tril=torch.diag(scale))
        >>> [mvn.batch_shape, mvn.event_shape]
        [torch.Size([]), torch.Size([3])]
        >>> normal = Normal(loc, scale)
        >>> [normal.batch_shape, normal.event_shape]
        [torch.Size([3]), torch.Size([])]
        >>> diagn = Independent(normal, 1)
        >>> [diagn.batch_shape, diagn.event_shape]
        [torch.Size([]), torch.Size([3])]

    Args:
        base_distribution (torch.distributions.distribution.Distribution): a
            base distribution
        reinterpreted_batch_ndims (int): the number of batch dims to
            reinterpret as event dims
    """
    arg_constraints: Dict[str, constraints.Constraint] = {}

    def __init__(
        self, base_distribution, reinterpreted_batch_ndims, validate_args=None
    ):
        if reinterpreted_batch_ndims > len(base_distribution.batch_shape):
            raise ValueError(
                "Expected reinterpreted_batch_ndims <= len(base_distribution.batch_shape), "
                f"actual {reinterpreted_batch_ndims} vs {len(base_distribution.batch_shape)}"
            )
        # Split the combined shape so that the rightmost
        # `reinterpreted_batch_ndims` batch dims, plus the base event dims,
        # become the new event shape.
        shape = base_distribution.batch_shape + base_distribution.event_shape
        event_dim = reinterpreted_batch_ndims + len(base_distribution.event_shape)
        batch_shape = shape[: len(shape) - event_dim]
        event_shape = shape[len(shape) - event_dim :]
        self.base_dist = base_distribution
        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(Independent, _instance)
        batch_shape = torch.Size(batch_shape)
        # The base distribution still treats the reinterpreted dims as batch
        # dims, so expand it over the new batch shape plus those dims.
        new.base_dist = self.base_dist.expand(
            batch_shape + self.event_shape[: self.reinterpreted_batch_ndims]
        )
        new.reinterpreted_batch_ndims = self.reinterpreted_batch_ndims
        super(Independent, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    @property
    def has_rsample(self):
        return self.base_dist.has_rsample

    @property
    def has_enumerate_support(self):
        if self.reinterpreted_batch_ndims > 0:
            return False
        return self.base_dist.has_enumerate_support

    @constraints.dependent_property
    def support(self):
        result = self.base_dist.support
        if self.reinterpreted_batch_ndims:
            result = constraints.independent(result, self.reinterpreted_batch_ndims)
        return result

    @property
    def mean(self):
        return self.base_dist.mean

    @property
    def mode(self):
        return self.base_dist.mode

    @property
    def variance(self):
        return self.base_dist.variance

    def sample(self, sample_shape=torch.Size()):
        return self.base_dist.sample(sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        return self.base_dist.rsample(sample_shape)

    def log_prob(self, value):
        # Sum the base log-density over the reinterpreted batch dims, which
        # are the rightmost dims of the base distribution's result.
        log_prob = self.base_dist.log_prob(value)
        return _sum_rightmost(log_prob, self.reinterpreted_batch_ndims)

    def entropy(self):
        entropy = self.base_dist.entropy()
        return _sum_rightmost(entropy, self.reinterpreted_batch_ndims)

    def enumerate_support(self, expand=True):
        if self.reinterpreted_batch_ndims > 0:
            raise NotImplementedError(
                "Enumeration over cartesian product is not implemented"
            )
        return self.base_dist.enumerate_support(expand=expand)

    def __repr__(self):
        return (
            self.__class__.__name__
            + f"({self.base_dist}, {self.reinterpreted_batch_ndims})"
        )
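As a quick check of the reshaping behavior described in the docstring, a minimal doctest-style sketch (using only the public torch.distributions exports):

    >>> import torch
    >>> from torch.distributions import Independent, Normal
    >>> normal = Normal(torch.zeros(3), torch.ones(3))
    >>> diagn = Independent(normal, 1)
    >>> x = torch.zeros(3)
    >>> normal.log_prob(x).shape  # one log-density per dimension
    torch.Size([3])
    >>> diagn.log_prob(x).shape  # reinterpreted dim summed away
    torch.Size([])
    >>> torch.allclose(diagn.log_prob(x), normal.log_prob(x).sum())
    True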