Summary:
This PR addresses some numerical issues of Sigmoid/StickBreakingTransform, where these transforms give +-inf when the unconstrained values move to +-20 areas.
For example, with
```
t = torch.distributions.SigmoidTransform()
x = torch.tensor(20.)
t.inv(t(x)), t.log_abs_det_jacobian(x, t(x))
```
current behaviour the inverse will return `inf` and logdet return `-inf` while this PR makes it to `15.9424` and `-15.9424`.
And for
```
t = torch.distributions.StickBreakingTransform()
x = torch.tensor([20., 20.])
t.inv(t(x)), t.log_abs_det_jacobian(x, t(x))
```
current value is `(inf, nan)` and `-inf` for logdet, while this PR makes it `[16.6355, 71.3942]` and `-47.8272` for logdet.
Although these finite values are wrong and seems unavoidable, it is better than returning `inf` or `nan` in my opinion. This is useful in HMC where despite that the grad will be zero when the unconstrained parameter moves to unstable area (due to clipping), velocity variable will force the parameter move to another area which by chance can move the parameter out of unstable area. But inf/nan can be useful to stop doing inference early. So the changes in this PR might be inappropriate.
I also fix some small issues of `_Simplex` and `_RealVector` constraints where batch shape of the input is not respected when checking validation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20288
Differential Revision: D15742047
Pulled By: ezyang
fbshipit-source-id: b427ed1752c41327abb3957f98d4b289307a7d17
Summary:
Currently, when the input of MVN is precision matrix, we take inverse to convert the result to covariance matrix. This, however, will easily make the covariance matrix not positive definite, hence will trigger a cholesky error.
For example,
```
import torch
torch.manual_seed(0)
x = torch.randn(10)
P = torch.exp(-(x - x.unsqueeze(-1)) ** 2)
torch.distributions.MultivariateNormal(loc=torch.ones(10), precision_matrix=P)
```
will trigger `RuntimeError: cholesky_cpu: U(8,8) is zero, singular U.`
This PR uses some math tricks ([ref](https://nbviewer.jupyter.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril)) to only take inverse of a triangular matrix, hence increase the stability.
cc fritzo, neerajprad , SsnL
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21366
Differential Revision: D15696972
Pulled By: ezyang
fbshipit-source-id: cec13f7dfdbd06dee94b8bed8ff0b3e720c7a188
Summary:
Unit tests that hang on clock64() calls are now fixed.
test_gamma_gpu_sample is now fixed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19307
Differential Revision: D14953420
Pulled By: bddppq
fbshipit-source-id: efe807b54e047578415eb1b1e03f8ad44ea27c13
Summary:
The derivative of the Cholesky decomposition was previously a triangular matrix.
Changelog:
- Modify the derivative of Cholesky from a triangular matrix to symmetric matrix
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19116
Differential Revision: D14935470
Pulled By: ezyang
fbshipit-source-id: 1c1c76b478c6b99e4e16624682842cb632e8e8b9
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18291
ghimport-source-id: d6e95e899bd320407967df41435801e54864ba62
Stack from [ghstack](https://github.com/ezyang/ghstack):
* #18292 Add test for #17271 (torch.exp incorrect for 2**31 size tensor)
* **#18291 Correctly call superclass setUp in TestCase subclasses.**
This makes PYTORCH_TEST_SKIP_FAST work correctly for more
tests, reducing the wasted testing effort on our slow_test job.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Differential Revision: D14567643
fbshipit-source-id: 40cf1d6556e0dd0a0550ff3d9ffed8b6000f8191
Summary:
Addresses #15738, using fritzo's suggestion. This adds a `torch._sample_dirichlet` method in `Distributions.cpp` and `Distributions.cu`.
- For CPU, this leads to no perf hit since all we do is to promote the `alpha` to double when getting the gamma samples (the gamma sampler anyways uses `accscalar_t`(double for CPU)) and cast it back to float32 on return.
- I have added an analogous method for CUDA as well, but the default sampler for CUDA uses scalar_t for efficiency, so I have kept it as that. With this, I do not see the bias towards 1 as reported in #15738 with `float32`, but there is a spurious mode at 0.5, as would be expected. Users would need to explicitly use `float64` for GPU to not see the spurious mode at 0.5. (EDIT: see note below, it appears that the bias issue is still there for certain builds).
Added some tests and checked that there is no perf regression. My experience with C++ is very limited, so apologies in advance if I missed something basic. cc. ailzhang, fritzo, fmassa
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17488
Differential Revision: D14410301
Pulled By: ezyang
fbshipit-source-id: 62b2f694b4642685eab06db96d74ce28e05c3992
Summary:
This is the first round of enabling unit tests that work on ROCm 2.1 in my tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16871
Differential Revision: D13997662
Pulled By: bddppq
fbshipit-source-id: d909a3f7dd5fc8f85f126bf0613751c8e4ef949f
Summary:
* we do not need EAP packages any longer as the antistatic feature is now in the release
* consistently install the rccl package
* Skip one unit test that has regressed with 2.1
* Follow-up PRs will use 2.1 features once deployed on CI
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16808
Differential Revision: D13992645
Pulled By: bddppq
fbshipit-source-id: 37ca9a1f104bb140bd2b56d403e32f04c4fbf4f0
Summary:
This PR removes the usage of _finfo defined in torch.distributions.utils and changes the call sites
to use torch.finfo instead
Differential Revision: D13451936
Pulled By: soumith
fbshipit-source-id: 6dbda3a6179d9407bc3396bf1a2baf3e85bc4cf2
Summary:
This enables the distributions and utils test sets for ROCm.
Individual tests are enabled that now pass due to fixes in HIP/HCC/libraries versions in white rabbit.
For attention: bddppq ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13166
Differential Revision: D12814759
Pulled By: bddppq
fbshipit-source-id: ea70e775c707d7a8d2776fede6154a755adef43e
Summary:
- fixes weights-contiguous requirement for THCUNN Convolutions
- Add tests that conv backward pass works for non-contiguous weights
- fix RNN tests / error messages to be consistent and pass
- relax weight grad precision for fp16 for a particular test
- fix regression of CMAKE_PREFIX_PATH not passing through
- add missing skipIfNoLapack annotations where needed
Differential Revision: D12918456
Pulled By: soumith
fbshipit-source-id: 8642d36bffcc6f2957800d6afa1e10bef2a91d05
Summary:
This PR performs a renaming of the function `potrf` responsible for the Cholesky
decomposition on positive definite matrices to `cholesky` as NumPy and TF do.
Billing of changes
- make potrf cname for cholesky in Declarations.cwrap
- modify the function names in ATen/core
- modify the function names in Python frontend
- issue warnings when potrf is called to notify users of the change
Reviewed By: soumith
Differential Revision: D10528361
Pulled By: zou3519
fbshipit-source-id: 19d9bcf8ffb38def698ae5acf30743884dda0d88
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12794
common.py is used in base_module for almost all tests in test/. The
name of this file is so common that can easily conflict with other dependencies
if they happen to have another common.py in the base module. Rename the file to
avoid conflict.
Reviewed By: orionr
Differential Revision: D10438204
fbshipit-source-id: 6a996c14980722330be0a9fd3a54c20af4b3d380
Summary:
This fixes a broadcasting error with the `StudentT` distribution
- [x] added a regression test
- [x] strengthened parameter broadcasting tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12148
Differential Revision: D10099226
Pulled By: soumith
fbshipit-source-id: 0c5eb14180d158f8fff28ceb9e7cd3471c2bb803
Summary:
The earlier tests had around 80 warnings, and now there are 6 warnings: these are due to JIT
The changes remove the wrapping of a Tensor by a Tensor constructor, which emits warnings due to the changes in https://github.com/pytorch/pytorch/pull/11061 .
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12038
Differential Revision: D10033392
Pulled By: apaszke
fbshipit-source-id: b1faf368e650d062d7983f9932511bee4702a893
Summary:
This PR:
- adds a `.expand` method for `TransformedDistribution` along the lines of #11341.
- uses this method to simplify `.expand` in distribution classes that subclass off of `TransformedDistribution`.
- restores testing of `TransformedDistribution` fixtures.
- fixes some bugs wherein we were not setting certain attributes in the expanded instances, and adds tests for `.mean` and `.variance` which use these attributes.
There are many cases where users directly use `TransformedDistribution` rather than subclassing off it. In such cases, it seems rather inconvenient to have to write a separate class just to define a `.expand` method. The default implementation should suffice in these cases.
cc. fritzo, vishwakftw, alicanb
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11607
Differential Revision: D9818225
Pulled By: soumith
fbshipit-source-id: 2c4b3812b9a03e6985278cfce0f9a127ce536f23
Summary:
This adds tests in tests/test_distributions.py to ensure that all methods of `Distribution` objects are jittable.
I've replaced a few samplers with jittable versions:
- `.uniform_()` -> `torch.rand()`
- `.exponential_()` -> `-(-torch.rand()).log1p()`
- `.normal_()` -> `torch.normal(torch.zeros(...), torch.ones(...), ...)`
Some jit failures remain, and are marked in test_distributions.py
- `Cauchy` and `HalfCauchy` do not support sampling due to missing `.cauchy_()`
- `Binomial` does not support `.enumerate_support()` due to `arange` ignoring its first arg.
- `MultivariateNormal`, `LowRankMultivariateNormal` do not support `.mean`, `.entropy`
- [x] Currently some tests fail (I've skipped those) due to unavailability of `aten::uniform` and `aten::cauchy` in the jit. Can someone suggest how to add these? I tried to add declarations to `torch/csrc/ir.cpp` and `torch/csrc/passes/shape_analysis.cpp`, but that resulted in "Couldn't find operator" errors.
- [x] There are still lots of `TracerWarning`s that something doesn't match something. I'm not sure whether these are real.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11560
Differential Revision: D9816327
Pulled By: apaszke
fbshipit-source-id: 72ec998ea13fc4c76d1ed003d9502e0fbaf728b8
Summary:
This adds a `.expand` method for distributions that is akin to the `torch.Tensor.expand` method for tensors. It returns a new distribution instance with batch dimensions expanded to the desired `batch_shape`. Since this calls `torch.Tensor.expand` on the distribution's parameters, it does not allocate new memory for the expanded distribution instance's parameters.
e.g.
```python
>>> d = dist.Normal(torch.zeros(100, 1), torch.ones(100, 1))
>>> d.sample().shape
torch.Size([100, 1])
>>> d.expand([100, 10]).sample().shape
torch.Size([100, 10])
```
We have already been using the `.expand` method in Pyro in our [patch](https://github.com/uber/pyro/blob/dev/pyro/distributions/torch.py#L10) of `torch.distributions`. We use this in our models to enable dynamic broadcasting. This has also been requested by a few users on the distributions slack, and we believe will be useful to the larger community.
Note that currently, there is no convenient and efficient way to expand distribution instances:
- Many distributions use `TransformedDistribution` (or wrap over another distribution instance. e.g. `OneHotCategorical` uses a `Categorical` instance) under the hood, or have lazy parameters. This makes it difficult to collect all the relevant parameters, broadcast them and construct new instances.
- In the few cases where this is even possible, the resulting implementation would be inefficient since we will go through a lot of broadcasting and args validation logic in `__init__.py` that can be avoided.
The `.expand` method allows for a safe and efficient way to expand distribution instances. Additionally, this bypasses `__init__.py` (using `__new__` and populating relevant attributes) since we do not need to do any broadcasting or args validation (which was already done when the instance was first created). This can result in significant savings as compared to constructing new instances via `__init__` (that said, the `sample` and `log_prob` methods will probably be the rate determining steps in many applications).
e.g.
```python
>>> a = dist.Bernoulli(torch.ones([10000, 1]), validate_args=True)
>>> %timeit a.expand([10000, 100])
15.2 µs ± 224 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
>>> %timeit dist.Bernoulli(torch.ones([10000, 100]), validate_args=True)
11.8 ms ± 153 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```
cc. fritzo, apaszke, vishwakftw, alicanb
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11341
Differential Revision: D9728485
Pulled By: soumith
fbshipit-source-id: 3b94c23bc6a43ee704389e6287aa83d1e278d52f
Summary:
This adds an optional `expand=True` kwarg to the `distribution.expand_support()` method, to get a distribution's support without expanding the values over the distribution's `batch_shape`.
- The default `expand=True` preserves the current behavior, whereas `expand=False` collapses the batch dimensions.
e.g.
```python
In [47]: d = dist.OneHotCategorical(torch.ones(3, 5) * 0.5)
In [48]: d.batch_shape
Out[48]: torch.Size([3])
In [49]: d.enumerate_support()
Out[49]:
tensor([[[1., 0., 0., 0., 0.],
[1., 0., 0., 0., 0.],
[1., 0., 0., 0., 0.]],
[[0., 1., 0., 0., 0.],
[0., 1., 0., 0., 0.],
[0., 1., 0., 0., 0.]],
[[0., 0., 1., 0., 0.],
[0., 0., 1., 0., 0.],
[0., 0., 1., 0., 0.]],
[[0., 0., 0., 1., 0.],
[0., 0., 0., 1., 0.],
[0., 0., 0., 1., 0.]],
[[0., 0., 0., 0., 1.],
[0., 0., 0., 0., 1.],
[0., 0., 0., 0., 1.]]])
In [50]: d.enumerate_support().shape
Out[50]: torch.Size([5, 3, 5])
In [51]: d.enumerate_support(expand=False)
Out[51]:
tensor([[[1., 0., 0., 0., 0.]],
[[0., 1., 0., 0., 0.]],
[[0., 0., 1., 0., 0.]],
[[0., 0., 0., 1., 0.]],
[[0., 0., 0., 0., 1.]]])
In [52]: d.enumerate_support(expand=False).shape
Out[52]: torch.Size([5, 1, 5])
```
**Motivation:**
- Currently `enumerate_support` builds up tensors of size `support + batch_shape + event_shape`, but the values are *repeated* over the `batch_shape` (adding little in the way of information). This can lead to expensive matrix operations over large tensors when `batch_shape` is large (see, example above), often leading to OOM issues. We use `expand=False` in Pyro for message passing inference. e.g. when enumerating over the state space in a Hidden Markov Model. This creates sparse tensors that capture the markov dependence, and allows for the possibility of using optimized matrix operations over these sparse tensors. `expand=True`, on the other hand, will create tensors that scale exponentially in size with the length of the Markov chain.
- We have been using this in our [patch](https://github.com/uber/pyro/blob/dev/pyro/distributions/torch.py) of `torch.distributions` in Pyro. The interface has been stable, and it is already being used in a few Pyro algorithms. We think that this is more broadly applicable and will be of interest to the larger distributions community.
cc. apaszke, fritzo, alicanb
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11231
Differential Revision: D9696290
Pulled By: soumith
fbshipit-source-id: c556f8ff374092e8366897ebe3f3b349538d9318
Summary:
`__repr__` currently fails for distributions with lazy attributes in PyTorch master, throwing a `KeyError`. This fixes the issue.
**Additionally:**
- Added `logits` to `arg_constraints` for distributions that accept either `probs` or `logits`. This is both to have `__repr__` display the `logits` param when available, and to be able to do validation checks (e.g. NaN checks) when the logit parametrization is used. fritzo, alicanb - I think there were reasons why we had not done so in the first place, but I am unable to recall now. It passes all the tests, but let me know if there is something that I am missing at the moment.
- There are certain distributions, e.g. `OneHotCategorical` which won't show any parameters because it uses a `categorical` instance under the hood and neither `logits` / `probs` in `arg_constraints` are present in the instance's `__dict__`. This isn't addressed in this PR.
cc. vishwakftw, fritzo, nadavbh12, apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11263
Differential Revision: D9654959
Pulled By: apaszke
fbshipit-source-id: 16f5b20243fe8e2c13e9c528050d4df0b8ea6e45
Summary:
Resubmit #10416 with fixed tests . This is to remove implicit conversion from gpu to cpu in when calling numpy to keep behavior match others.
It requires users to move the tensor back to cpu() before call numpy functions on it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10553
Differential Revision: D9350212
Pulled By: ailzhang
fbshipit-source-id: 9317d8fea925d4b20ae3150e2c1b39ba5c9c9d0a
Summary:
Support broadcasting in _kl_categorical_categorical
this makes it possible to do:
```
import torch.distributions as dist
import torch
p_dist = dist.Categorical(torch.ones(1,10))
q_dist = dist.Categorical(torch.ones(100,10))
dist.kl_divergence(p_dist, q_dist)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10533
Differential Revision: D9341252
Pulled By: soumith
fbshipit-source-id: 34575b30160b43b6c9e4c3070dd7ef07c00ff5d7
Summary:
This causes numpy to yield to the torch functions,
e.g. instead of numpy array/scalar __mul__ converting the tensor to
an array, it will now arrange for the Tensor __rmul__ to be called.
Fixes case 2 of #9468
I also makes case 3 and 4 equivalent but does not fix them.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9651
Differential Revision: D8948079
Pulled By: ezyang
fbshipit-source-id: bd42c04e96783da0bd340f37f4ac3559e9bbf8db
Summary:
This implements the two-parameter Weibull distribution, with scale $\lambda$ and shape $k$ parameters as described on [Wikipedia](https://en.wikipedia.org/wiki/Weibull_distribution).
**Details**
- We implement as a transformed exponential distribution, as described [here](https://en.wikipedia.org/wiki/Weibull_distribution#Related_distributions).
- The `weibull_min` variance function in scipy does not yet support a vector of distributions, so our unit test uses a scalar distribution instead of a vector.
Example of the bug:
```
>>> sp.stats.expon(np.array([0.5, 1, 2])).var() # fine
array([1., 1., 1.])
>>> sp.stats.weibull_min(c=np.array([0.5, 1, 2])).var() # buggy
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/usr/local/lib/python3.7/site-packages/scipy/stats/_distn_infrastructure.py", line 490, in var
return self.dist.var(*self.args, **self.kwds)
File "/usr/local/lib/python3.7/site-packages/scipy/stats/_distn_infrastructure.py", line 1242, in var
res = self.stats(*args, **kwds)
File "/usr/local/lib/python3.7/site-packages/scipy/stats/_distn_infrastructure.py", line 1038, in stats
if np.isinf(mu):
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9454
Differential Revision: D8863574
Pulled By: SsnL
fbshipit-source-id: 1ad3e175b469eee2b6af98e7b379ea170d3d9787
Summary:
This pull request implements low rank multivariate normal distribution where the covariance matrix has the from `W @ W.T + D`. Here D is a diagonal matrix, W has shape n x m where m << n. It used "matrix determinant lemma" and "Woodbury matrix identity" to save computational cost.
During the way, I also revise MultivariateNormal distribution a bit. Here are other changes:
+ `torch.trtrs` works with cuda tensor. So I tried to use it instead of `torch.inverse`.
+ Use `torch.matmul` instead of `torch.bmm` in `_batch_mv`. The former is faster and simpler.
+ Use `torch.diagonal` for `_batch_diag`
+ Reimplement `_batch_mahalanobis` based on `_batch_trtrs_lower`.
+ Use trtrs to compute term2 of KL.
+ `variance` relies on `scale_tril` instead of `covariance_matrix`
TODO:
- [x] Resolve the fail at `_gradcheck_log_prob`
- [x] Add test for KL
cc fritzo stepelu apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/8635
Differential Revision: D8951893
Pulled By: ezyang
fbshipit-source-id: 488ee3db6071150c33a1fb6624f3cfd9b52760c3
Summary:
cc vishwakftw
Also added a check if none of the input tensors in `gradcheck` have `requires_grad=True`.
Closes https://github.com/pytorch/pytorch/pull/9192
Differential Revision: D8739401
Pulled By: SsnL
fbshipit-source-id: 81bb3aa0b5c04eb209b137a4bd978e040e76cbcd
* Add memory leak check in CUDA tests
* Tracking multi-GPU too
* fix run_test.py not running __name__ == '__main__' content; add test for make_cuda_memory_checked_test
* add a comment
* skip if cuda
* 1. Change the wrapper to a method in common.py:TestCase
2. Refactor common constants/method that initialize CUDA context into common_cuda.py
3. Update some test files to use TEST_CUDA and TEST_MULTIGPU
* Fix MaxUnpool3d forward memory leak
* Fix MultiLabelMarginCriterion forward memory leak
* Fix MultiMarginLoss backward memory leak
* default doCUDAMemoryCheck to False
* make the wrapper skip-able
* use TEST_MULTIGPU
* add align_corners=True/False tests for Upsample; fix TEST_CUDNN
* finalize interface
* VolumetricMaxUnpooling_updateOutput
* fix test_nccl
* rename THC caching allocator methods to be clearer
* make the wrapped function a method
* address comments; revert changes to aten/src/THC/THCCachingAllocator.cpp
* fix renamed var
* Change backward calls to grad to avoid memory leak from #7343; Replace unnecesary create_graph=True with retain_graph=True
* fix gradgradcheck use of make_non_contiguous
* allow non-contguous target
* remove unnecessray .grad.zero_()
* remove contiguous_detach
* fix PReLU double backward always returning ggW as a scalar
* let noncontig gO require grad
* move requires_grad to return
* Don't allow requires_grad to be set on integer Tensor constructors in tensor_new.
* Fix autograd test.
* Fix test_distributions.
* Fix test_jit.
* Fix NN tests.
* fix for #7532: clamping the return value of uniform.cdf() to the range [0,1]
* removed whitespace around equals to pass flake8 tests
* added a test for uniform.cdf() with arguments outside support
* Refactor standard_gamma and implement CUDA gamma sampling
* Attempt fixes for AT_CUDA_ENABLED changes
* Gamma cuda and cpu forward as ATen native
* implement standard_gamma_grad_cuda
* update native_test.cpp, try to fix windows and various cuda version compiles
* searching a windows fix via CI... use std:: for math
* casting some constants in the calculation, compute at float for half precision
* whitespace fixes
* add acctype to do half->float computation, include HALF in generation, cast locally rather than tensors
* fix cuda8 half compilation
* always use scalar_cast with CUDACC, lock CPU generator, CPU acctype = double\nThank you for your review comments!