Fixes #144196
Extends #144106 and #144110
## Open Problems:
- [ ] Annotating with `numbers.Number` is a bad idea; we should consider using `float`, `SupportsFloat`, or some `Protocol` instead. https://github.com/pytorch/pytorch/pull/144197#discussion_r1903324769
# Notes
- `beta.py`: needed to add `type: ignore` since `broadcast_all` is untyped.
- `categorical.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- ~~`dirichlet.py`: replaced `axis` with `dim` arguments.~~ #144402
- `geometric.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- ~~`independent.py`: fixed bug in `Independent.__init__` where `tuple[int, ...]` could be passed to `Distribution.__init__` instead of `torch.Size`.~~ **EDIT:** turns out the bug is related to typing of `torch.Size`. #144218
- `independent.py`: made `Independent` a generic class of its base distribution.
- `multivariate_normal.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- `relaxed_bernoulli.py`: added class-level type hint for `base_dist`.
- `relaxed_categorical.py`: added class-level type hint for `base_dist`.
- ~~`transforms.py`: Added missing argument to docstring of `ReshapeTransform`~~ #144401
- ~~`transforms.py`: Fixed bug in `AffineTransform.sign` (could return `Tensor` instead of `int`).~~ #144400
- `transforms.py`: Added `type: ignore` comments to `AffineTransform.log_abs_det_jacobian`[^1]; replaced `torch.abs(scale)` with `scale.abs()`.
- `transforms.py`: Added `type: ignore` comments to `AffineTransform.__eq__`[^1].
- `transforms.py`: Fixed type hint on `CumulativeDistributionTransform.domain`. Note that this is still an LSP violation, because `Transform.domain` is defined as `Constraint`, but `Distribution.domain` is defined as `Optional[Constraint]`.
- skipped: `constraints.py`, `constraint_registry.py`, `kl.py`, `utils.py`, `exp_family.py`, `__init__.py`.
## Remark
`TransformedDistribution`: `__init__` uses the check `if reinterpreted_batch_ndims > 0:`, which can lead to the creation of `Independent` distributions with only 1 component. This results in awkward code like `base_dist.base_dist` in `LogisticNormal`.
```python
import torch
from torch.distributions import *
b1 = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
b2 = MultivariateNormal(torch.tensor([0.0]), torch.eye(1))
t = StickBreakingTransform()
d1 = TransformedDistribution(b1, t)
d2 = TransformedDistribution(b2, t)
print(d1.base_dist) # Independent with 1 dimension
print(d2.base_dist) # MultivariateNormal
```
One could consider changing this to `if reinterpreted_batch_ndims > 1:`.
[^1]: Usage of `isinstance(value, numbers.Real)` leads to problems with static typing, as the `numbers` module is not supported by `mypy` (see <https://github.com/python/mypy/issues/3186>). This results in us having to add type-ignore comments in several places.
[^2]: Otherwise, we would have to add a bunch of `type: ignore` comments to make `mypy` happy, as it isn't able to perform the type narrowing. Ideally, such code should be replaced with structural pattern matching once support for Python 3.9 is dropped.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144197
Approved by: https://github.com/malfet
Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
Fixes #76772, #144196
Extends #144106
- added type annotations to `lazy_property` (a rough illustrative sketch follows this list).
- added type annotations to all `@property` and `@lazy_property` methods inside the `torch.distributions` module.
- added a simple type-check unit test to ensure type inference is working.
- replaced deprecated annotations like `typing.List` with their modern counterparts.
- simplified `torch.Tensor` hints to plain `Tensor`, since signatures otherwise become very verbose.
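For context, the kind of typing this enables can be sketched as follows; this is a hypothetical simplification, not the actual implementation in `torch.distributions.utils`:
```python
from typing import Callable, Generic, Optional, TypeVar, overload

T = TypeVar("T")
R = TypeVar("R")

class lazy_property(Generic[T, R]):
    """Cache the wrapped method's result on the instance after first access."""

    def __init__(self, wrapped: Callable[[T], R]) -> None:
        self.wrapped = wrapped

    @overload
    def __get__(self, instance: None, obj_type: Optional[type] = None) -> "lazy_property[T, R]": ...
    @overload
    def __get__(self, instance: T, obj_type: Optional[type] = None) -> R: ...

    def __get__(self, instance, obj_type=None):
        if instance is None:
            return self  # accessed on the class: return the descriptor itself
        value = self.wrapped(instance)
        setattr(instance, self.wrapped.__name__, value)  # cache on the instance
        return value
```
With the generic `__get__` overloads, a static checker can infer the wrapped method's return type when the attribute is accessed on an instance.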
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144110
Approved by: https://github.com/Skylion007
This PR comprises a few small contributions:
1. `PowerTransform` returned a sign of `+1` irrespective of exponent. However, it should return the sign of the exponent because the gradient has the same sign as the exponent. That issue has been fixed.
2. Added tests to catch errors akin to 1. in the future.
3. Added an `InverseGamma` distribution as a `TransformedDistribution` with `PowerTransform(-1)` and `Gamma` base distribution. The `InverseGamma` is often used as a prior for the length scale of Gaussian processes to aggressively suppress short length scales (see [here](https://betanalpha.github.io/assets/case_studies/gaussian_processes.html#323_Informative_Prior_Model) for a discussion).
Note: I added a `positive` constraint for the support of the inverse gamma distribution because the `PowerTransform(-1)` can fail for `nonnegative` constraints if the random variable is zero.
```python
>>> torch.distributions.InverseGamma(0.5, 1.0).log_prob(torch.zeros(1))
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-8-758aa22deacd> in <module>
----> 1 torch.distributions.InverseGamma(0.5, 1.0).log_prob(torch.zeros(1))
~/git/pytorch/torch/distributions/transformed_distribution.py in log_prob(self, value)
140 """
141 if self._validate_args:
--> 142 self._validate_sample(value)
143 event_dim = len(self.event_shape)
144 log_prob = 0.0
~/git/pytorch/torch/distributions/distribution.py in _validate_sample(self, value)
298 valid = support.check(value)
299 if not valid.all():
--> 300 raise ValueError(
301 "Expected value argument "
302 f"({type(value).__name__} of shape {tuple(value.shape)}) "
ValueError: Expected value argument (Tensor of shape (1,)) to be within the support (GreaterThan(lower_bound=0.0)) of the distribution InverseGamma(), but found invalid values:
tensor([0.])
```
This differs from the scipy implementation.
```python
>>> scipy.stats.invgamma(0.5).pdf(0)
0.0
```
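For reference, the construction described in point 3 can be sketched as follows (a minimal example with illustrative parameter values):
```python
import torch
from torch.distributions import Gamma, InverseGamma, TransformedDistribution
from torch.distributions.transforms import PowerTransform

concentration, rate = torch.tensor(3.0), torch.tensor(2.0)
x = torch.tensor([0.5, 1.0, 2.0])

# InverseGamma written out explicitly as a power-transformed Gamma.
manual = TransformedDistribution(Gamma(concentration, rate), PowerTransform(-1.0))

print(InverseGamma(concentration, rate).log_prob(x))
print(manual.log_prob(x))  # should agree up to floating point error
```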
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104501
Approved by: https://github.com/fritzo, https://github.com/ezyang
`log1p` offers better precision near zero, since computing `1 + x` directly rounds away any `x` smaller than the float epsilon. For `soft_margin_loss` this also requires one fewer kernel invocation, which for `numel=1e7` gives me a 1.2x speedup on CUDA and a 1.1x speedup on CPU.
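As a quick illustration of the precision argument (not the loss-function code itself):
```python
import torch

x = torch.tensor(1e-10)        # float32
print(torch.log(1 + x))        # tensor(0.): 1 + 1e-10 rounds to exactly 1
print(torch.log1p(x))          # tensor(1.0000e-10): small values survive
```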
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92114
Approved by: https://github.com/ngimel, https://github.com/lezcano
The `PositiveDefiniteTransform` is required to transform from an unconstrained space to positive definite matrices, e.g. to support testing the Wishart mode in #76690. It is a simple extension of the `LowerCholeskyTransform`.
I've also added a small test that ensures the generated data belong to the domain of the associated transform. Previously, the data generated for the inverse transform of the `LowerCholeskyTransform` wasn't part of the domain, and the test only passed because the comparison uses `equal_nan=True`.
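A minimal usage sketch of the new transform:
```python
import torch
from torch.distributions.transforms import PositiveDefiniteTransform

t = PositiveDefiniteTransform()
x = torch.randn(3, 3)             # unconstrained square matrix
y = t(x)                          # positive definite matrix
print(torch.linalg.eigvalsh(y))   # all eigenvalues should be strictly positive
```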
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76777
Approved by: https://github.com/lezcano, https://github.com/fritzo, https://github.com/soumith
`Transform` is not currently pickleable if the inverse transform cache `_inv` is not `None` because `_inv` is a `weakref` which cannot be serialized by `pickle`.
The following succeeds.
```python
>>> import torch as th
>>> import pickle
>>> dist = th.distributions.TransformedDistribution(
...     th.distributions.Normal(0, 1),
...     [th.distributions.AffineTransform(2, 3)]
... )
>>> th.save(dist, "some-file.pt")
```
But the transformed distribution can no longer be pickled after evaluating `log_prob` (which implicitly creates `_inv`).
```python
>>> dist.log_prob(th.linspace(0, 1, 10))
>>> th.save(dist, "some-file.pt")
TypeError: cannot pickle 'weakref' object
```
This PR fixes the issue by setting `_inv` to `None` in `__getstate__`. cc @fritzo, @neerajprad
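A minimal sketch of the idea (the actual patch may differ in detail):
```python
class Transform:
    ...  # existing implementation

    def __getstate__(self):
        state = self.__dict__.copy()
        state["_inv"] = None  # drop the weakref; .inv lazily recreates it on demand
        return state
```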
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81707
Approved by: https://github.com/fritzo
Summary:
Fixes hiding of the `CatTransform` docstring due to a misplaced type annotation in https://github.com/pytorch/pytorch/issues/45689.
Also fixes that annotation and adds one to `StackTransform` to keep it consistent with `CatTransform`.
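For context, a small illustration of how a misplaced annotation hides a docstring (class and attribute names here are hypothetical):
```python
class Example:
    transforms: list  # annotated assignment placed before the docstring
    """This string is just an ordinary expression, not the class docstring."""

print(Example.__doc__)  # None: a docstring must be the first statement in the body
```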
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73747
Reviewed By: mruberry
Differential Revision: D34649927
Pulled By: neerajprad
fbshipit-source-id: e4594fd76020a37ac4cada88e5fb08191984e911
(cherry picked from commit cec256c3242d1cf55073a980060af87c1fd59ac9)
Summary:
This pull request introduces `SoftplusTransform` to `torch.distributions.transforms`. `SoftplusTransform` transforms via the mapping `Softplus(x) = log(1 + exp(x))`. Note that the transform is different from [`torch.nn.Softplus`](https://pytorch.org/docs/stable/generated/torch.nn.Softplus.html#torch.nn.Softplus), which has additional `beta` and `threshold` parameters. An inverse and `log_abs_det_jacobian` for a more complex `SoftplusTransform` can be added in the future.
vitkl fritzo
Addresses the issue discussed here: [pyro issue 855](https://github.com/pyro-ppl/numpyro/issues/855)
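A minimal usage sketch (the round trip is only approximate in floating point):
```python
import torch
from torch.distributions.transforms import SoftplusTransform

t = SoftplusTransform()
x = torch.linspace(-5.0, 5.0, 11)
y = t(x)                                        # log(1 + exp(x)), always positive
print(torch.allclose(t.inv(y), x, atol=1e-5))   # inverse recovers x (approximately)
```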
Pull Request resolved: https://github.com/pytorch/pytorch/pull/52300
Reviewed By: albanD, ejguan
Differential Revision: D34082655
Pulled By: neerajprad
fbshipit-source-id: 6114e74ee5d73c1527191bed612a142d691e2094
(cherry picked from commit a181a3a9e53a34214a503d38760ad7778d08a680)
Summary:
This PR adds a transform that uses the cumulative distribution function of a given probability distribution.
For example, the following code constructs a simple Gaussian copula.
```python
import torch
from torch.distributions import (
    LKJCholesky, MultivariateNormal, Normal, TransformedDistribution, Weibull,
)
from torch.distributions.transforms import CumulativeDistributionTransform

# Construct a Gaussian copula from a multivariate normal.
base_dist = MultivariateNormal(
    loc=torch.zeros(2),
    scale_tril=LKJCholesky(2).sample(),
)
transform = CumulativeDistributionTransform(Normal(0, 1))
copula = TransformedDistribution(base_dist, [transform])
```
The following snippet creates a "wrapped" Gaussian copula for correlated positive variables with Weibull marginals.
```python
transforms = [
    CumulativeDistributionTransform(Normal(0, 1)),
    CumulativeDistributionTransform(Weibull(4, 2)).inv,
]
wrapped_copula = TransformedDistribution(base_dist, transforms)
```
cc fritzo
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72495
Reviewed By: ejguan
Differential Revision: D34085919
Pulled By: albanD
fbshipit-source-id: 7917391519a96b0d9b54c52db65d1932f961d070
(cherry picked from commit 572196146e)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/50496
Fixes https://github.com/pytorch/pytorch/issues/34859
Fixes https://github.com/pytorch/pytorch/issues/21596
This fixes many bugs involving `TransformedDistribution` and `ComposeTransform` when the component transforms change their event shapes. Part of the fix is to introduce an `IndependentTransform` analogous to `distributions.Independent` and `constraints.independent`, and to introduce methods `Transform.forward_shape()` and `.inverse_shape()`. I have followed fehiepsi's suggestion and replaced `.input_event_dim` -> `.domain.event_dim` and `.output_event_dim` -> `.codomain.event_dim`. This allows us to deprecate `.event_dim` as an attribute.
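A small sketch of the new pieces (illustrative shapes and values):
```python
from torch.distributions.transforms import ExpTransform, IndependentTransform, ReshapeTransform

# Treat the last dim of an elementwise transform as part of the event.
t = IndependentTransform(ExpTransform(), reinterpreted_batch_ndims=1)
print(t.domain.event_dim, t.codomain.event_dim)   # 1 1

# Shape bookkeeping via forward_shape()/inverse_shape().
r = ReshapeTransform(in_shape=(2, 3), out_shape=(6,))
print(r.forward_shape((5, 2, 3)))                 # (5, 6)
print(r.inverse_shape((5, 6)))                    # (5, 2, 3)
```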
## Summary of changes
- Fixes `TransformedDistribution` and `ComposeTransform` shape errors.
- Fixes a behavior bug in `LogisticNormal`.
- Fixes `kl_divergence(TransformedDistribution, TransformedDistribution)`
- Adds methods `Transform.forward_shape()`, `.inverse_shape()` which are required for correct shape computations in `TransformedDistribution` and `ComposeTransform`.
- Adds an `IndependentTransform`.
- Adds a `ReshapeTransform` which is invaluable in testing shape logic in `ComposeTransform` and `TransformedDistribution` and which will be used by stefanwebb's flowtorch.
- Fixes incorrect default values in `constraints.dependent.event_dim`.
- Documents the `.event_dim` and `.is_discrete` attributes.
## Changes planned for follow-up PRs
- Memoize `constraints.dependent_property` as we do with `lazy_property`, since we now consult those properties much more often.
## Tested
- [x] added a test for `Dist.support` vs `Dist(**params).support` to ensure static and dynamic attributes agree.
- [x] refactoring is covered by existing tests
- [x] add test cases for `ReshapeTransform`
- [x] add a test for `TransformedDistribution` on a wide grid of input shapes
- [x] added a regression test for https://github.com/pytorch/pytorch/issues/34859
cc fehiepsi feynmanliang stefanwebb
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50581
Reviewed By: ezyang, glaringlee, jpchen
Differential Revision: D26024247
Pulled By: neerajprad
fbshipit-source-id: f0b9a296f780ff49659b132409e11a29985dde9b
Summary:
Fixes https://github.com/pytorch/pytorch/issues/44530
As explained in the issue description, CatTransform does not work with event_dim > 0.
This PR fixes this. If this gets approved I am hoping to do the same for StackTransform as well.
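A sketch of the now-working case, using components with different event dims (shapes and choices here are illustrative):
```python
import torch
from torch.distributions.transforms import CatTransform, ExpTransform, SoftmaxTransform

# SoftmaxTransform acts on whole vectors (event_dim 1), ExpTransform elementwise.
t = CatTransform([SoftmaxTransform(), ExpTransform()], dim=1, lengths=[3, 2])
x = torch.randn(4, 5)
y = t(x)          # first 3 columns of each row -> simplex, last 2 -> exp
print(y.shape)    # torch.Size([4, 5])
```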
fritzo, can you take a look at this?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49111
Reviewed By: neerajprad
Differential Revision: D25526005
Pulled By: ezyang
fbshipit-source-id: e14430093f550d5e0da7a311f9cd44796807830f
Summary:
This adds a transform to convert a real vector of dimension D * (D - 1) / 2 into the Cholesky factor of a D x D correlation matrix. This follows the implementation in [NumPyro](https://github.com/pyro-ppl/numpyro/blob/master/numpyro/distributions/transforms.py) by fehiepsi. This is needed for the LKJDistribution which will be added in a subsequent PR.
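A short usage sketch (here D = 3, so the unconstrained vector has length 3):
```python
import torch
from torch.distributions.transforms import CorrCholeskyTransform

t = CorrCholeskyTransform()
x = torch.randn(3)       # D * (D - 1) / 2 = 3 unconstrained values
L = t(x)                 # lower Cholesky factor of a 3 x 3 correlation matrix
print(L @ L.T)           # unit diagonal, off-diagonal entries in (-1, 1)
```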
Also in line with the ongoing effort to refactor distributions test, this moves the transforms test into its own file that uses pytest with parametrized fixtures.
For review:
fehiepsi - could you help review the math?
fritzo - do you have any suggestions for what to do about the event dimension (more details are in the comment below)?
ezyang - could you review the changes in `run_test.py`? Instead of a separate `PYTEST_TESTS`, I have clubbed these tests in `USE_PYTEST_LIST` to avoid duplicate logic. The only difference is that we do not anymore check if pytest is not installed and exclude the tests in the list. I figured that if existing tests are already using pytest, this should not matter.
TODOs (probably not all can be satisfied at the same time):
- [x] Use operations that are JIT friendly, i.e. the transform works with different sized input under JIT.
- [x] Resolve test failures - currently `arange(scalar_tensor)` fails on certain backends but this is needed for JIT. Maybe we should only support same sized tensor under JIT?
- [x] Add tests to check that the transform gives correct gradients and is in agreement with the `log_det_jacobian`.
- [x] Add `input_event_dim` and `output_event_dim` to `CorrCholeskyTransform`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48041
Reviewed By: zhangguanheng66
Differential Revision: D25262505
Pulled By: neerajprad
fbshipit-source-id: 5a57e1c19d8230b53592437590b9169bdf2f71e9
Summary:
This resolves an issue observed by stefanwebb where the composition of multiple transforms is cached only if all components are cached.
This PR adds a new method `.with_cache()` so that e.g. you can compose a normalizing flow (that needs to be cached) with a `SigmoidTransform` (that wasn't already cached) by calling `.with_cache()` on the latter. This issue also comes up when composing non-cached constraint transforms as returned by `transform_to()` and `biject_to()`: after this PR you can call `transform_to(constraints.positive).with_cache()` to get a cached `ExpTransform`.
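A small sketch of the intended usage:
```python
import torch
from torch.distributions import constraints, transform_to

t = transform_to(constraints.positive)   # ExpTransform with cache_size=0
t_cached = t.with_cache()                # same transform, but with cache_size=1
x = torch.randn(3, requires_grad=True)
y = t_cached(x)
print(t_cached.inv(y) is x)              # True: the cached input tensor is returned
```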
## Tested
- [x] added a unit test
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36882
Differential Revision: D21155914
Pulled By: ezyang
fbshipit-source-id: 3c06e63785ca2503e08a5cd7532aff81882835e9
Summary:
Removes older `torch.stack`-based logic in favor of `torch.diagonal()` and `torch.diag_embed()`.
I see a 100x speedup in my application, where my batched matrix has shape `(800, 32, 32)`.
```py
import torch
from torch.distributions import constraints, transform_to
x = torch.randn(800, 32, 32, requires_grad=True)
# Before this PR:
%%timeit
transform_to(constraints.lower_cholesky)(x).sum().backward()
# 579 ms ± 34.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# After this PR:
%%timeit
transform_to(constraints.lower_cholesky)(x).sum().backward()
# 4.5 ms ± 241 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```
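For intuition, a rough sketch of the `diagonal`/`diag_embed` approach (not the exact library code):
```py
import torch

def to_lower_cholesky(x: torch.Tensor) -> torch.Tensor:
    # Keep the strictly lower triangle and re-embed an exponentiated (positive)
    # diagonal; works batched, e.g. for x of shape (800, 32, 32).
    diag = x.diagonal(dim1=-2, dim2=-1).exp()
    return x.tril(-1) + torch.diag_embed(diag)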
Pull Request resolved: https://github.com/pytorch/pytorch/pull/24131
Differential Revision: D16764035
Pulled By: ezyang
fbshipit-source-id: 170cdb0d924cdc94cd5ad3b75d1427404718d437
Summary:
This modernizes distributions code by replacing a few uses of `.contiguous().view()` with `.reshape()`, fixing a sampling bug in the `Categorical` distribution.
The bug is exercised by the following test:
```py
import torch
import torch.distributions as dist

batch_shape = (1, 2, 1, 3, 1)
sample_shape = (4,)
cardinality = 2
logits = torch.randn(batch_shape + (cardinality,))
dist.Categorical(logits=logits).sample(sample_shape)
# RuntimeError: invalid argument 2: view size is not compatible with
# input tensor's size and stride (at least one dimension spans across
# two contiguous subspaces). Call .contiguous() before .view().
# at ../aten/src/TH/generic/THTensor.cpp:203
```
I have verified this works locally, but I have not added this as a regression test because it is unlikely to regress (the code is now simpler).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23328
Differential Revision: D16510678
Pulled By: colesbury
fbshipit-source-id: c125c1a37d21d185132e8e8b65241c86ad8ad04b
Summary:
This PR addresses some numerical issues of SigmoidTransform and StickBreakingTransform, where these transforms return +-inf when the unconstrained values move into the +-20 range.
For example, with
```python
t = torch.distributions.SigmoidTransform()
x = torch.tensor(20.)
t.inv(t(x)), t.log_abs_det_jacobian(x, t(x))
```
the current behaviour is that the inverse returns `inf` and the log-det returns `-inf`, while with this PR they become `15.9424` and `-15.9424` respectively.
And for
```python
t = torch.distributions.StickBreakingTransform()
x = torch.tensor([20., 20.])
t.inv(t(x)), t.log_abs_det_jacobian(x, t(x))
```
the current values are `(inf, nan)` for the inverse and `-inf` for the log-det, while with this PR they become `[16.6355, 71.3942]` and `-47.8272` respectively.
Although these finite values are wrong and this seems unavoidable, it is better than returning `inf` or `nan` in my opinion. This is useful in HMC: even though the gradient will be zero when the unconstrained parameter moves into an unstable region (due to clipping), the velocity variable will force the parameter to move elsewhere, which by chance can take it out of the unstable region. But inf/nan can be useful for stopping inference early, so the changes in this PR might be inappropriate.
I also fix some small issues in the `_Simplex` and `_RealVector` constraints, where the batch shape of the input was not respected during validation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20288
Differential Revision: D15742047
Pulled By: ezyang
fbshipit-source-id: b427ed1752c41327abb3957f98d4b289307a7d17
Summary:
This adds calls to `super().__init__()` in three classes in torch.distributions.
This is needed when `Distribution` and `Transform` objects are used with multiple inheritance, as e.g. combined with `torch.nn.Module`s. For example
```py
import torch

class MyModule(torch.distributions.Transform, torch.nn.Module):
    ...
```
cc martinjankowiak esling who have wanted to use this pattern, e.g. in #16756
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16772
Differential Revision: D13978633
Pulled By: soumith
fbshipit-source-id: 8bc6cca1747cd74d32135ee2fe588bba2ea796f1
Summary:
The earlier tests had around 80 warnings, and now there are 6 warnings; these are due to the JIT.
The changes remove the wrapping of a Tensor by a Tensor constructor, which emits warnings due to the changes in https://github.com/pytorch/pytorch/pull/11061.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12038
Differential Revision: D10033392
Pulled By: apaszke
fbshipit-source-id: b1faf368e650d062d7983f9932511bee4702a893
* Codemod to update our codebase to 0.4 standard
* Update some of the test scripts
* remove Variable in test_clip_grad_value
* fix _symbolic_override_wrapper_maker
This avoids promotion from python float to torch.Tensor for AffineTransform. This appears to be needed so that constraint registration works across CPU and all GPUs.
Previous discussion at 3a25db73c8 (r176361909)
Background:
There are three basic types of objects in torch.distributions:
- Distributions are flyweight objects constructed from tensor or float args. They always promote float args to tensors.
- Transforms are longer-lived objects (sometimes cached; some are static globals). They can take float arguments. This PR makes AffineTransform avoid promoting float args to tensors.
- Constraints are long-lived objects. They can take either float or tensor arguments. They do not promote floats to tensors. These are relatively symbolic and are not much more than partially evaluated comparisons, e.g. constraints.positive is basically a symbolic version of lambda x: x > 0 that can be stored in a ConstraintRegistry table.
The Problem:
Sometimes we want to apply transform_to(constraints.positive) to a torch.cuda.FloatTensor. This is fine since
transform_to(constraints.positive)(x)
= ExpTransform()(x)
= x.exp()
which works with any tensor type.
Other times we want to apply transform_to(constraints.greater_than(1.5)) to a torch.cuda.FloatTensor. This is problematic before this PR since
transform_to(constraints.greater_than(1.5))(x)
= ComposeTransform([ExpTransform(), AffineTransform(1.5, 1)])(x)
= AffineTransform(1.5, 1)(x.exp())
= t.loc + t.scale * x.exp() # where t = AffineTransform(1.5, 1)
Before this PR, AffineTransform would promote t.loc and t.scale to tensors. This promotion can happen as early as library load time for some transforms, e.g. transform_to(constraints.unit_interval). Therefore before this PR, the second example would error at t.scale * x.exp() because t.scale is a [default] torch.FloatTensor whereas x.exp() is a torch.cuda.FloatTensor.
Proposed solution:
This PR merely adds support for python floats as the .loc and .scale parameters of AffineTransform. This should suffice for most purposes since only AffineTransform and a handful of parameter-free transforms are ever stored in the global transform_to and biject_to registries.
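As an illustration of what this enables (a sketch; the CUDA branch runs only if a GPU is available):
```python
import torch
from torch.distributions import constraints, transform_to

# ComposeTransform([ExpTransform(), AffineTransform(1.5, 1)]), as described above.
t = transform_to(constraints.greater_than(1.5))
x = torch.randn(3)
print(t(x))                # 1.5 + exp(x), works on CPU tensors
if torch.cuda.is_available():
    print(t(x.cuda()))     # works on CUDA tensors too, since loc/scale stay python floats
```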
Alternative solutions include:
- allowing promotion from torch.FloatTensor to all other tensor types, e.g. torch.cuda.FloatTensor.
- adding a handful of specific parameter-free transforms like NegateTransform() in lieu of AffineTransform(0, -1).
Tested: added a regression test
* Support python floats in AffineTransform
* Update docstrings