mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fixes #144196 Extends #144106 and #144110 ## Open Problems: - [ ] Annotating with `numbers.Number` is a bad idea; we should consider using `float`, `SupportsFloat` or some `Protocol`. https://github.com/pytorch/pytorch/pull/144197#discussion_r1903324769 # Notes - `beta.py`: needed to add `type: ignore` since `broadcast_all` is untyped. - `categorical.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2]. - ~~`dirichlet.py`: replaced `axis` with `dim` arguments.~~ #144402 - `geometric.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2]. - ~~`independent.py`: fixed bug in `Independent.__init__` where `tuple[int, ...]` could be passed to `Distribution.__init__` instead of `torch.Size`.~~ **EDIT:** turns out the bug is related to typing of `torch.Size`. #144218 - `independent.py`: made `Independent` a generic class of its base distribution. - `multivariate_normal.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2]. - `relaxed_bernoulli.py`: added class-level type hint for `base_dist`. - `relaxed_categorical.py`: added class-level type hint for `base_dist`. - ~~`transforms.py`: Added missing argument to docstring of `ReshapeTransform`~~ #144401 - ~~`transforms.py`: Fixed bug in `AffineTransform.sign` (could return `Tensor` instead of `int`).~~ #144400 - `transforms.py`: Added `type: ignore` comments to `AffineTransform.log_abs_det_jacobian`[^1]; replaced `torch.abs(scale)` with `scale.abs()`. - `transforms.py`: Added `type: ignore` comments to `AffineTransform.__eq__`[^1]. - `transforms.py`: Fixed type hint on `CumulativeDistributionTransform.domain`. Note that this is still an LSP violation, because `Transform.domain` is defined as `Constraint`, but `Distribution.domain` is defined as `Optional[Constraint]`. - skipped: `constraints.py`, `constraints_registry.py`, `kl.py`, `utils.py`, `exp_family.py`, `__init__.py`. 
## Remark

`TransformedDistribution`: `__init__` uses the check `if reinterpreted_batch_ndims > 0:`, which can lead to the creation of `Independent` distributions with only 1 component. This results in awkward code like `base_dist.base_dist` in `LogisticNormal`.

```python
import torch
from torch.distributions import *

b1 = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
b2 = MultivariateNormal(torch.tensor([0.0]), torch.eye(1))
t = StickBreakingTransform()
d1 = TransformedDistribution(b1, t)
d2 = TransformedDistribution(b2, t)
print(d1.base_dist)  # Independent with 1 dimension
print(d2.base_dist)  # MultivariateNormal
```

One could consider changing this to `if reinterpreted_batch_ndims > 1:`.

[^1]: Usage of `isinstance(value, numbers.Real)` leads to problems with static typing, as the `numbers` module is not supported by `mypy` (see <https://github.com/python/mypy/issues/3186>). This results in us having to add type-ignore comments in several places.

[^2]: Otherwise, we would have to add a bunch of `type: ignore` comments to make `mypy` happy, as it isn't able to perform the type narrowing. Ideally, such code should be replaced with structural pattern matching once support for Python 3.9 is dropped.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144197 Approved by: https://github.com/malfet Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
107 lines
3.6 KiB
Python
107 lines
3.6 KiB
Python
# mypy: allow-untyped-defs
|
|
from typing import Optional, Union
|
|
|
|
import torch
|
|
from torch import nan, Tensor
|
|
from torch.distributions import constraints
|
|
from torch.distributions.transformed_distribution import TransformedDistribution
|
|
from torch.distributions.transforms import AffineTransform, PowerTransform
|
|
from torch.distributions.uniform import Uniform
|
|
from torch.distributions.utils import broadcast_all, euler_constant
|
|
|
|
|
|
# Public names exported by this module (what ``from <module> import *`` exposes).
__all__ = ["Kumaraswamy"]
|
|
|
|
|
|
def _moments(a, b, n):
    """
    Return the ``n``-th raw moment ``E[X**n]`` of a Kumaraswamy(a, b) variable.

    The closed form is ``b * B(1 + n / a, b)``, where the Beta function ``B`` is
    evaluated through ``torch.lgamma`` so the computation stays in log space until
    the final exponentiation.

    Args:
        a: first concentration (``concentration1``); must be a Tensor or broadcast
            with one, since ``n / a`` is passed to ``torch.lgamma``.
        b: second concentration (``concentration0``); a Tensor.
        n: order of the moment.
    """
    shifted = n / a + 1
    log_beta = torch.lgamma(shifted) + torch.lgamma(b) - torch.lgamma(shifted + b)
    return torch.exp(log_beta) * b
|
|
|
|
|
|
class Kumaraswamy(TransformedDistribution):
    r"""
    Samples from a Kumaraswamy distribution.

    The distribution is built by pushing ``U ~ Uniform(0, 1)`` through
    ``X = (1 - U ** (1 / concentration0)) ** (1 / concentration1)``, which yields the
    CDF ``F(x) = 1 - (1 - x ** concentration1) ** concentration0`` on ``[0, 1]``.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Kumaraswamy(torch.tensor([1.0]), torch.tensor([1.0]))
        >>> m.sample()  # sample from a Kumaraswamy distribution with concentration alpha=1 and beta=1
        tensor([ 0.1729])

    Args:
        concentration1 (float or Tensor): 1st concentration parameter of the distribution
            (often referred to as alpha)
        concentration0 (float or Tensor): 2nd concentration parameter of the distribution
            (often referred to as beta)
    """

    arg_constraints = {
        "concentration1": constraints.positive,
        "concentration0": constraints.positive,
    }
    support = constraints.unit_interval
    has_rsample = True

    def __init__(
        self,
        concentration1: Union[Tensor, float],
        concentration0: Union[Tensor, float],
        validate_args: Optional[bool] = None,
    ) -> None:
        self.concentration1, self.concentration0 = broadcast_all(
            concentration1, concentration0
        )
        # Standard uniform base sharing the parameters' shape, dtype and device.
        base_dist = Uniform(
            torch.full_like(self.concentration0, 0),
            torch.full_like(self.concentration0, 1),
            validate_args=validate_args,
        )
        # u -> u ** (1 / c0) -> 1 - u ** (1 / c0) -> (1 - u ** (1 / c0)) ** (1 / c1)
        transforms = [
            PowerTransform(exponent=self.concentration0.reciprocal()),
            AffineTransform(loc=1.0, scale=-1.0),
            PowerTransform(exponent=self.concentration1.reciprocal()),
        ]
        super().__init__(base_dist, transforms, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a new instance whose parameters are expanded to ``batch_shape``."""
        new = self._get_checked_instance(Kumaraswamy, _instance)
        new.concentration1 = self.concentration1.expand(batch_shape)
        new.concentration0 = self.concentration0.expand(batch_shape)
        return super().expand(batch_shape, _instance=new)

    @property
    def mean(self) -> Tensor:
        """First raw moment ``E[X]``."""
        return _moments(self.concentration1, self.concentration0, 1)

    @property
    def mode(self) -> Tensor:
        """
        Mode ``((a - 1) / (a * b - 1)) ** (1 / a)`` with ``a = concentration1`` and
        ``b = concentration0``.

        Returns NaN where either concentration is below 1 (the density is unbounded
        at an endpoint there) and where ``a == b == 1`` (uniform case, ``0 / 0``).
        Returns 0 when ``a == 1 < b`` and 1 when ``b == 1 < a``.
        """
        a = self.concentration1
        b = self.concentration0
        # Evaluate in log space. For a, b >= 1 both (a - 1) and (a * b - 1) are
        # non-negative, so their logs are real. Taking logs of the negated forms
        # (1 - b) or (1 - a * b) instead would yield NaN for every a, b > 1.
        log_mode = ((a - 1).log() - (a * b - 1).log()) / a
        # Out-of-place masking keeps the autograd graph free of in-place writes.
        log_mode = log_mode.masked_fill((b < 1) | (a < 1), nan)
        return log_mode.exp()

    @property
    def variance(self) -> Tensor:
        """``E[X**2] - E[X]**2``."""
        return _moments(self.concentration1, self.concentration0, 2) - torch.pow(
            self.mean, 2
        )

    def entropy(self):
        """
        Differential entropy.

        ``H = (1 - 1/b) + (1 - 1/a) * H_b - log(a) - log(b)`` with
        ``a = concentration1``, ``b = concentration0`` and
        ``H_b = digamma(b + 1) + euler_constant`` (the generalized harmonic number).
        """
        t1 = 1 - self.concentration1.reciprocal()
        t0 = 1 - self.concentration0.reciprocal()
        H0 = torch.digamma(self.concentration0 + 1) + euler_constant
        return (
            t0
            + t1 * H0
            - torch.log(self.concentration1)
            - torch.log(self.concentration0)
        )
|