Commit Graph

54 Commits

Author SHA1 Message Date
Yuanyuan Chen
a60d9e1f6d Fix flake8 B028 warnings (#166224)
This PR fixes flake8 B028 warnings by specifying `stacklevel=2` in `warnings.warn` calls. The advantage is that users get more contextual information about where PyTorch warnings originate.
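A minimal sketch of what `stacklevel=2` changes (the function name here is hypothetical):

```python
import warnings

def deprecated_api():
    # stacklevel=2 attributes the warning to the caller of this function,
    # not to the warnings.warn call inside the library.
    warnings.warn("deprecated_api is deprecated", FutureWarning, stacklevel=2)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    deprecated_api()  # the reported filename/lineno point at this call site

assert caught[0].category is FutureWarning
```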

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166224
Approved by: https://github.com/ezyang
2025-10-26 06:18:55 +00:00
Yuanhao Ji
b027cb8f9e [Docs] Add Description of validate_args for torch.distributions (#152173)
Fixes #152165
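A pure-Python sketch of the `validate_args` behavior being documented (the class below is an illustration, not the real `torch.distributions.Distribution`):

```python
class Distribution:
    # Module-wide default, togglable via set_default_validate_args
    # (sketch of the torch.distributions pattern, not the real class).
    _validate_args = True

    @staticmethod
    def set_default_validate_args(value: bool) -> None:
        Distribution._validate_args = value

    def __init__(self, probs: float, validate_args=None):
        self.probs = probs
        validate = Distribution._validate_args if validate_args is None else validate_args
        if validate and not 0.0 <= probs <= 1.0:
            raise ValueError(f"probs must lie in [0, 1], got {probs}")

Distribution(0.5)                       # validated, passes
Distribution(1.5, validate_args=False)  # per-instance opt-out skips the check
try:
    Distribution(1.5)                   # default validation rejects this
    raised = False
except ValueError:
    raised = True
assert raised
```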

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152173
Approved by: https://github.com/soulitzer
2025-04-30 18:01:20 +00:00
Randolf Scholz
6c38b9be73 [typing] Add type hints to __init__ methods in torch.distributions. (#144197)
Fixes #144196
Extends #144106 and #144110

## Open Problems:

- [ ] Annotating with `numbers.Number` is a bad idea; we should consider using `float`, `SupportsFloat`, or some `Protocol`. https://github.com/pytorch/pytorch/pull/144197#discussion_r1903324769

# Notes

- `beta.py`: needed to add `type: ignore` since `broadcast_all` is untyped.
- `categorical.py`: converted `else` branches of mutually exclusive arguments to `if` branches[^2].
- ~~`dirichlet.py`: replaced `axis` with `dim` arguments.~~ #144402
- `geometric.py`: converted `else` branches of mutually exclusive arguments to `if` branches[^2].
- ~~`independent.py`: fixed bug in `Independent.__init__` where `tuple[int, ...]` could be passed to `Distribution.__init__` instead of `torch.Size`.~~ **EDIT:** turns out the bug is related to typing of `torch.Size`. #144218
- `independent.py`: made `Independent` a generic class of its base distribution.
- `multivariate_normal.py`: converted `else` branches of mutually exclusive arguments to `if` branches[^2].
- `relaxed_bernoulli.py`: added class-level type hint for `base_dist`.
- `relaxed_categorical.py`: added class-level type hint for `base_dist`.
- ~~`transforms.py`: Added missing argument to docstring of `ReshapeTransform`~~ #144401
- ~~`transforms.py`: Fixed bug in `AffineTransform.sign` (could return `Tensor` instead of `int`).~~ #144400
- `transforms.py`: Added `type: ignore` comments to `AffineTransform.log_abs_det_jacobian`[^1]; replaced `torch.abs(scale)` with `scale.abs()`.
- `transforms.py`: Added `type: ignore` comments to `AffineTransform.__eq__`[^1].
- `transforms.py`: Fixed type hint on `CumulativeDistributionTransform.domain`. Note that this is still an LSP violation, because `Transform.domain` is defined as `Constraint`, but `Distribution.domain` is defined as `Optional[Constraint]`.
- skipped: `constraints.py`, `constraint_registry.py`, `kl.py`, `utils.py`, `exp_family.py`, `__init__.py`.
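The `else`-to-`if` conversion mentioned in the notes can be sketched with a hypothetical helper (`pick_param` is illustrative, not a real PyTorch function):

```python
from typing import List, Optional

def pick_param(probs: Optional[List[float]] = None,
               logits: Optional[List[float]] = None) -> List[float]:
    # Mutually exclusive arguments: exactly one of probs/logits must be given.
    if (probs is None) == (logits is None):
        raise ValueError("Either `probs` or `logits` must be specified, but not both.")
    if probs is not None:
        return probs
    # An explicit `if` (rather than a bare `else`) lets mypy narrow
    # `logits` from Optional[List[float]] to List[float] here.
    if logits is not None:
        return logits
    raise AssertionError("unreachable")

assert pick_param(probs=[0.2, 0.8]) == [0.2, 0.8]
assert pick_param(logits=[1.0, -1.0]) == [1.0, -1.0]
```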

## Remark

`TransformedDistribution`: `__init__` uses the check `if reinterpreted_batch_ndims > 0:`, which can lead to the creation of `Independent` distributions with only 1 component. This results in awkward code like `base_dist.base_dist` in `LogisticNormal`.

```python
import torch
from torch.distributions import *
b1 = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
b2 = MultivariateNormal(torch.tensor([0.0]), torch.eye(1))
t = StickBreakingTransform()
d1 = TransformedDistribution(b1, t)
d2 = TransformedDistribution(b2, t)
print(d1.base_dist)  # Independent with 1 dimension
print(d2.base_dist)  # MultivariateNormal
```

One could consider changing this to `if reinterpreted_batch_ndims > 1:`.

[^1]: Usage of `isinstance(value, numbers.Real)` leads to problems with static typing, as the `numbers` module is not supported by `mypy` (see <https://github.com/python/mypy/issues/3186>). This results in us having to add type-ignore comments in several places.
[^2]: Otherwise, we would have to add a bunch of `type: ignore` comments to make `mypy` happy, as it isn't able to perform the type narrowing. Ideally, such code should be replaced with structural pattern matching once support for Python 3.9 is dropped.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144197
Approved by: https://github.com/malfet

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
2025-04-06 17:50:35 +00:00
chilli
5e4cf3e6ad Moved .all() checks for distributions to _is_all_true (#145029)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145029
Approved by: https://github.com/Skylion007, https://github.com/zou3519
2025-01-18 07:55:48 +00:00
Randolf Scholz
355b0bc7e3 [typing] Add type hints to @property and @lazy_property in torch.distributions. (#144110)
Fixes #76772, #144196
Extends #144106

- added type annotations to `lazy_property`.
- added type annotations to all `@property` and `@lazy_property` members inside the `torch.distributions` module.
- added a simple type-check unit test to ensure type inference is working.
- replaced deprecated annotations like `typing.List` with their modern counterparts.
- simplified `torch.Tensor` hints to plain `Tensor`; otherwise signatures can become very verbose.
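The caching behavior behind `lazy_property` can be sketched as a non-data descriptor (a simplified stand-in for the real `torch.distributions.utils.lazy_property`, with a toy `Normal` for illustration):

```python
class lazy_property:
    # Non-data descriptor: computes the value once, then caches it on the
    # instance so later attribute lookups bypass the descriptor entirely.
    def __init__(self, fget):
        self.fget = fget
        self.__name__ = fget.__name__

    def __get__(self, obj, objtype=None):
        if obj is None:
            return self
        value = self.fget(obj)
        setattr(obj, self.__name__, value)  # instance dict now shadows us
        return value

class Normal:
    def __init__(self, loc: float, scale: float) -> None:
        self.loc, self.scale = loc, scale

    @lazy_property
    def variance(self) -> float:
        return self.scale ** 2

n = Normal(0.0, 2.0)
assert n.variance == 4.0
assert "variance" in n.__dict__  # cached after the first access
```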

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144110
Approved by: https://github.com/Skylion007
2025-01-07 19:27:36 +00:00
Christopher Yeh
e72e924eb5 Add correct typing annotations to rsample() for all distributions (#133516)
Fixes #133514
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133516
Approved by: https://github.com/Skylion007
2024-08-18 20:31:54 +00:00
Xuehai Pan
b25ef91bf1 [BE][Easy][18/19] enforce style for empty lines in import segments in torch/d*/ (#129770)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129770
Approved by: https://github.com/wconstab
2024-08-01 04:22:50 +00:00
Aaron Orenstein
7c12cc7ce4 Flip default value for mypy disallow_untyped_defs [6/11] (#127843)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127843
Approved by: https://github.com/oulgen
ghstack dependencies: #127842
2024-06-08 18:49:29 +00:00
Xuehai Pan
67ef2683d9 [BE] wrap deprecated function/class with typing_extensions.deprecated (#127689)
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.

Note that only warnings whose messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.
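A stdlib-only stand-in for the decorator's behavior (the PR itself uses `typing_extensions.deprecated`; this sketch only mimics it):

```python
import functools
import warnings

def deprecated(message, *, category=FutureWarning, stacklevel=1):
    # Minimal stand-in for typing_extensions.deprecated (sketch only).
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # +1 so the warning is attributed to the caller, not the wrapper.
            warnings.warn(message, category=category, stacklevel=stacklevel + 1)
            return fn(*args, **kwargs)
        return wrapper
    return decorator

@deprecated("old_fn() is deprecated; use new_fn() instead")
def old_fn():
    return 42

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = old_fn()

assert result == 42
assert caught[0].category is FutureWarning
```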

Resolves #126888

- #126888

This PR is split from PR #126898.

- #126898

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127689
Approved by: https://github.com/Skylion007
2024-06-02 12:30:43 +00:00
PyTorch MergeBot
033e733021 Revert "[BE] wrap deprecated function/class with typing_extensions.deprecated (#126898)"
This reverts commit 749a132fb0.

Reverted https://github.com/pytorch/pytorch/pull/126898 on behalf of https://github.com/fbgheith due to switching typing-extensions=4.3.0 to 4.9.0 causes internal failure ([comment](https://github.com/pytorch/pytorch/pull/126898#issuecomment-2142884456))
2024-05-31 19:47:24 +00:00
Xuehai Pan
749a132fb0 [BE] wrap deprecated function/class with typing_extensions.deprecated (#126898)
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.

Note that only warnings whose messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.

UPDATE: Use `FutureWarning` instead of `DeprecationWarning`.

Resolves #126888

- #126888

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126898
Approved by: https://github.com/albanD
2024-05-29 12:09:27 +00:00
Edward Z. Yang
b581e03850 Apply UFMT to torch/distributions/distribution.py, manually resolve fstrings (#106266)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106266
Approved by: https://github.com/Skylion007
2023-07-30 19:10:57 +00:00
Aaron Gokaslan
6d43c89f37 [BE]: Update Ruff to 0.0.280 (#105724)
Removes unused loop values in Python dictionary iteration. Automated fix from Ruff master.
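The kind of rewrite the linter performs, sketched on a plain dict:

```python
d = {"a": 1, "b": 2}

# Before: the value is bound on every iteration but never used.
before = [key for key, _value in d.items()]
# After the automated fix: iterate the keys directly.
after = [key for key in d]

assert before == after == ["a", "b"]
```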

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105724
Approved by: https://github.com/ezyang, https://github.com/janeyx99
2023-07-22 23:03:34 +00:00
Xuehai Pan
5b1cedacde [BE] [2/3] Rewrite super() calls in functorch and torch (#94588)
Rewrite Python built-in class `super()` calls. Only non-semantic changes should be applied.

- #94587
- #94588
- #94592

Also, methods with only a `super()` call are removed:

```diff
class MyModule(nn.Module):
-   def __init__(self):
-       super().__init__()
-
    def forward(self, ...):
        ...
```
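The zero-argument form behaves identically to the explicit Python 2 style it replaces, e.g.:

```python
class Base:
    def __init__(self, x):
        self.x = x

class Child(Base):
    def __init__(self, x):
        # Rewritten from the Python 2 form `super(Child, self).__init__(x)`;
        # the zero-argument form resolves the same entry in the MRO.
        super().__init__(x)

assert Child(5).x == 5
```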

Some cases that change the semantics should be kept unchanged. E.g.:

f152a79be9/caffe2/python/net_printer.py (L184-L190)

f152a79be9/test/test_jit_fuser_te.py (L2628-L2635)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94588
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-10 21:16:33 +00:00
Aaron Gokaslan
8fce9a09cd [BE]: pyupgrade Python to 3.8 - imports and object inheritance only (#94308)
Apply parts of pyupgrade to torch (starting with the safest changes).
This PR only does two things: removes the need to inherit from object and removes unused future imports.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94308
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-07 21:10:56 +00:00
Yanbo Liang
0ab4ab9f8d [Dynamo] Fix calling UserDefinedObject.func should pass self object (#92050)
Fixes #90834

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92050
Approved by: https://github.com/jansel
2023-01-21 05:47:01 +00:00
Ethan Pronovost
585d71513d Add type annotations to distribution.py (#87577)
As title.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87577
Approved by: https://github.com/kit1980
2022-10-26 18:50:48 +00:00
anjali411
3bcc19b29a Add __all__ to various submodules in torch.fx, distributions, distributed, package (#80367)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80367
Approved by: https://github.com/albanD
2022-06-27 21:27:30 +00:00
Till Hoffmann
40576bceaf Add mode property to distributions. (#76690)
This PR fixes #69466 and introduces some other minor changes. Tests are somewhat more involved because a reference implementation in `scipy` is not available; tests proceed differently for discrete and continuous distributions.

For continuous distributions, we evaluate the gradient of the `log_prob` at the mode. Tests pass if the gradient is zero OR (the mode is at the boundary of the support of the distribution AND the `log_prob` decreases as we move away from the boundary to the interior of the support).

For discrete distributions, the notion of a gradient is not well defined. We thus "look" ahead and behind one step (e.g. if the mode of a Poisson distribution is 9, we consider 8 and 10). If the step ahead/behind is still within the support of the distribution, we assert that the `log_prob` is smaller than at the mode.

For one-hot encoded distributions (currently just `OneHotCategorical`), we evaluate the underlying mode (i.e. encoded as an integral tensor), "advance" by one label to get another sample that should have lower probability using `other = (mode + 1) % event_size` and re-encode as one-hot. The resultant `other` sample should have lower probability than the mode.

Furthermore, Gamma, half Cauchy, and half normal distributions have their support changed from positive to nonnegative. This change is necessary because the mode of the "half" distributions is zero, and the mode of the gamma distribution is zero for `concentration <= 1`.
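The discrete look-ahead/behind check can be sketched without `scipy`, using the Poisson pmf in closed form (a simplified standalone version of the test logic described above, not the actual test code):

```python
import math

def poisson_log_prob(k: int, rate: float) -> float:
    # log P(K = k) for a Poisson(rate) distribution.
    return k * math.log(rate) - rate - math.lgamma(k + 1)

rate = 9.5
mode = math.floor(rate)  # the Poisson mode is floor(rate) for non-integer rate
lp_mode = poisson_log_prob(mode, rate)

# "Look" one step behind and ahead; within the support, both neighbors
# must have strictly lower log_prob than the mode.
for neighbor in (mode - 1, mode + 1):
    if neighbor >= 0:  # stay inside the support
        assert poisson_log_prob(neighbor, rate) < lp_mode
```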

cc @fritzo

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76690
Approved by: https://github.com/neerajprad
2022-05-11 18:26:56 +00:00
neerajprad
201f7d330a Remove duplicate check in distributions arg validation (#67741)
Summary:
Partial fix for https://github.com/pytorch/pytorch/issues/66800. (Duplicate of https://github.com/pytorch/pytorch/issues/67725 against pytorch/pytorch so as to trigger TorchBench)

https://github.com/pytorch/pytorch/issues/61056 added a more verbose error message for distributions failing argument validation. However, it did not replace the earlier error check as was originally intended and was flagged by xuzhao9 as being the potential cause of a perf regression in `test_eval[soft_actor_critic-cuda-eager]`.

xuzhao9: Is there a way for me to check if this resolves the perf issue you mentioned?

cc VitalyFedyunin ngimel

Note that existing tests already check for the error message and should verify that the removed lines are redundant.

RUN_TORCHBENCH: soft_actor_critic

Pull Request resolved: https://github.com/pytorch/pytorch/pull/67741

Reviewed By: neerajprad

Differential Revision: D32135675

Pulled By: xuzhao9

fbshipit-source-id: 37dfd3ff53b95017c763371979ab3a2c302a72b9
2021-11-03 10:41:41 -07:00
Fritz Obermeyer
81e36d02a6 Improve error message on invalid values to Distribution methods (#61056)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/18133

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61056

Reviewed By: jbschlosser

Differential Revision: D29510173

Pulled By: neerajprad

fbshipit-source-id: 205ec7de6c8576a73e77ee4bf01c30e99b38a52e
2021-07-06 15:44:55 -07:00
neerajprad
d76176cc1f Raise warning during validation when arg_constraints not defined (#50302)
Summary:
After we merged https://github.com/pytorch/pytorch/pull/48743, we noticed that some existing code that subclasses `torch.Distribution` started throwing `NotImplementedError`, since the constraints required for validation checks were not implemented.

```sh
File "torch/distributions/distribution.py", line 40, in __init__
  for param, constraint in self.arg_constraints.items():
File "torch/distributions/distribution.py", line 92, in arg_constraints
  raise NotImplementedError
```

This PR issues a `UserWarning` for such cases instead and gives a better warning message.
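A minimal sketch of the changed behavior (hypothetical classes, not the real implementation):

```python
import warnings

class Distribution:
    @property
    def arg_constraints(self):
        raise NotImplementedError

    def __init__(self, validate_args=True):
        if validate_args:
            try:
                constraints = self.arg_constraints
            except NotImplementedError:
                # Previously this NotImplementedError propagated to the user;
                # now subclasses without arg_constraints get a warning instead.
                warnings.warn(
                    f"{type(self).__name__} does not define arg_constraints; "
                    "skipping argument validation",
                    UserWarning,
                )
                return
            # ... validate each parameter against its constraint ...

class MySubclass(Distribution):
    pass  # forgot to define arg_constraints

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    MySubclass()  # warns instead of raising

assert caught[0].category is UserWarning
```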

cc. Balandat

Pull Request resolved: https://github.com/pytorch/pytorch/pull/50302

Reviewed By: Balandat, xuzhao9

Differential Revision: D25857315

Pulled By: neerajprad

fbshipit-source-id: 0ff9f81aad97a0a184735b1fe3a5d42025c8bcdf
2021-01-11 15:26:53 -08:00
Fritz Obermeyer
093aca082e Enable distribution validation if __debug__ (#48743)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/47123
Follows https://github.com/pyro-ppl/pyro/pull/2701

This turns on `Distribution` validation by default. The motivation is to favor beginners by providing helpful error messages. Advanced users focused on speed can disable validation by calling
```py
torch.distributions.Distribution.set_default_validate_args(False)
```
or by disabling individual distribution validation via `MyDistribution(..., validate_args=False)`.

In practice I have found many beginners forget or do not know about validation. Therefore I have [enabled it by default](https://github.com/pyro-ppl/pyro/pull/2701) in Pyro. I believe PyTorch could also benefit from this change. Indeed validation caught a number of bugs in `.icdf()` methods, in tests, and in PPL benchmarks, all of which have been fixed in this PR.

## Release concerns
- This may slightly slow down some models. Concerned users may disable validation.
- This may cause new `ValueErrors` in models that rely on unsupported behavior, e.g. `Categorical.log_prob()` applied to continuous-valued tensors (only {0,1}-valued tensors are supported).

We should clearly note this change in release notes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/48743

Reviewed By: heitorschueroff

Differential Revision: D25304247

Pulled By: neerajprad

fbshipit-source-id: 8d50f28441321ae691f848c55f71aa80cb356b41
2021-01-05 13:59:10 -08:00
Xu Zhao
146721f1df Fix typing errors in the torch.distributions module (#45689)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/42979.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/45689

Reviewed By: agolynski

Differential Revision: D24229870

Pulled By: xuzhao9

fbshipit-source-id: 5fc87cc428170139962ab65b71cacba494d46130
2020-10-12 10:29:45 -07:00
Morgan Funtowicz
c596683309 Rely on numel() == 1 to check if distribution parameters are scalar. (#17503)
Summary:
As discussed in #16952, this PR aims to improve the `__repr__` for distributions when the provided parameters are `torch.Tensor`s with only one element.

Currently, `__repr__()` relies on `dim() == 0`, leading to the following behaviour:

```
>>> torch.distributions.Normal(torch.tensor([1.0]), torch.tensor([0.1]))
Normal(loc: torch.Size([1]), scale: torch.Size([1]))
```

With this PR, the output looks like the following:
```
>>> torch.distributions.Normal(torch.tensor([1.0]), torch.tensor([0.1]))
Normal(loc: 1.0, scale: 0.10000000149011612)
```
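The `numel() == 1` rule can be sketched without torch (`FakeTensor` and `format_param` below are purely illustrative stand-ins):

```python
# Sketch of the __repr__ rule: show the scalar value when the parameter
# holds exactly one element (numel() == 1), its shape otherwise.
class FakeTensor:
    def __init__(self, values):
        self.values = values

    def numel(self):
        return len(self.values)

    def item(self):
        return self.values[0]

    @property
    def shape(self):
        return (len(self.values),)

def format_param(name, t):
    return f"{name}: {t.item() if t.numel() == 1 else t.shape}"

assert format_param("loc", FakeTensor([1.0])) == "loc: 1.0"
assert format_param("loc", FakeTensor([1.0, 2.0])) == "loc: (2,)"
```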
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17503

Differential Revision: D14245439

Pulled By: soumith

fbshipit-source-id: a440998905fd60cf2ac9a94f75706021dd9ce5bf
2019-02-28 13:36:17 -08:00
Fritz Obermeyer
0d366e1bde Support multiple inheritance in torch.distributions (#16772)
Summary:
This adds calls to `super().__init__()` in three classes in torch.distributions.

This is needed when `Distribution` and `Transform` objects are used with multiple inheritance, e.g. combined with a `torch.nn.Module`. For example
```py
class MyModule(torch.distributions.Transform, torch.nn.Module):
    ...
```
cc  martinjankowiak esling who have wanted to use this pattern, e.g. in #16756
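A minimal sketch of why the cooperative `super().__init__()` calls matter (class names hypothetical):

```python
class Transform:
    def __init__(self):
        super().__init__()  # cooperative: keep walking the MRO
        self.transform_ready = True

class Module:
    def __init__(self):
        super().__init__()
        self.module_ready = True

class MyModule(Transform, Module):
    pass

m = MyModule()
# Without the super().__init__() call in Transform, Module.__init__
# would never run and module_ready would be missing.
assert m.transform_ready and m.module_ready
```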
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16772

Differential Revision: D13978633

Pulled By: soumith

fbshipit-source-id: 8bc6cca1747cd74d32135ee2fe588bba2ea796f1
2019-02-07 01:37:57 -08:00
Fritz Obermeyer
2431eac7c0 Ensure most Distribution methods are jittable (#11560)
Summary:
This adds tests in test/test_distributions.py to ensure that all methods of `Distribution` objects are jittable.

I've replaced a few samplers with jittable versions:
- `.uniform_()` -> `torch.rand()`
- `.exponential_()` -> `-(-torch.rand()).log1p()`
- `.normal_()` -> `torch.normal(torch.zeros(...), torch.ones(...), ...)`

Some jit failures remain, and are marked in test_distributions.py
- `Cauchy` and `HalfCauchy` do not support sampling due to missing `.cauchy_()`
- `Binomial` does not support `.enumerate_support()` due to `arange` ignoring its first arg.
- `MultivariateNormal`, `LowRankMultivariateNormal` do not support `.mean`, `.entropy`

- [x] Currently some tests fail (I've skipped those) due to unavailability of `aten::uniform` and `aten::cauchy` in the jit. Can someone suggest how to add these? I tried to add declarations to `torch/csrc/ir.cpp` and `torch/csrc/passes/shape_analysis.cpp`, but that resulted in "Couldn't find operator" errors.
- [x] There are still lots of `TracerWarning`s that something doesn't match something. I'm not sure whether these are real.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11560

Differential Revision: D9816327

Pulled By: apaszke

fbshipit-source-id: 72ec998ea13fc4c76d1ed003d9502e0fbaf728b8
2018-09-13 19:55:01 -07:00
Neeraj Pradhan
80fa8e1007 Add .expand() method to distribution classes (#11341)
Summary:
This adds a `.expand` method for distributions that is akin to the `torch.Tensor.expand` method for tensors. It returns a new distribution instance with batch dimensions expanded to the desired `batch_shape`. Since this calls `torch.Tensor.expand` on the distribution's parameters, it does not allocate new memory for the expanded distribution instance's parameters.

e.g.
```python
>>> d = dist.Normal(torch.zeros(100, 1), torch.ones(100, 1))
>>> d.sample().shape
  torch.Size([100, 1])
>>> d.expand([100, 10]).sample().shape
  torch.Size([100, 10])
```

We have already been using the `.expand` method in Pyro in our [patch](https://github.com/uber/pyro/blob/dev/pyro/distributions/torch.py#L10) of `torch.distributions`. We use this in our models to enable dynamic broadcasting. This has also been requested by a few users on the distributions slack, and we believe will be useful to the larger community.

Note that currently, there is no convenient and efficient way to expand distribution instances:
 - Many distributions use `TransformedDistribution` (or wrap another distribution instance, e.g. `OneHotCategorical` uses a `Categorical` instance) under the hood, or have lazy parameters. This makes it difficult to collect all the relevant parameters, broadcast them, and construct new instances.
 - In the few cases where this is even possible, the resulting implementation would be inefficient since we will go through a lot of broadcasting and args validation logic in `__init__.py` that can be avoided.

The `.expand` method allows for a safe and efficient way to expand distribution instances. Additionally, this bypasses `__init__.py` (using `__new__` and populating relevant attributes) since we do not need to do any broadcasting or args validation (which was already done when the instance was first created). This can result in significant savings as compared to constructing new instances via `__init__` (that said, the `sample` and `log_prob` methods will probably be the rate determining steps in many applications).

e.g.
```python
>>> a = dist.Bernoulli(torch.ones([10000, 1]), validate_args=True)

>>> %timeit a.expand([10000, 100])
15.2 µs ± 224 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

>>> %timeit dist.Bernoulli(torch.ones([10000, 100]), validate_args=True)
11.8 ms ± 153 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```

cc. fritzo, apaszke, vishwakftw, alicanb
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11341

Differential Revision: D9728485

Pulled By: soumith

fbshipit-source-id: 3b94c23bc6a43ee704389e6287aa83d1e278d52f
2018-09-11 06:56:18 -07:00
Neeraj Pradhan
b3b1e7624d Optional expand=True kwarg in distribution.enumerate_support (#11231)
Summary:
This adds an optional `expand=True` kwarg to the `distribution.enumerate_support()` method, to get a distribution's support without expanding the values over the distribution's `batch_shape`.
 - The default `expand=True` preserves the current behavior, whereas `expand=False` collapses the batch dimensions.

e.g.
```python
In [47]: d = dist.OneHotCategorical(torch.ones(3, 5) * 0.5)

In [48]: d.batch_shape
Out[48]: torch.Size([3])

In [49]: d.enumerate_support()
Out[49]:
tensor([[[1., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0.]],

        [[0., 1., 0., 0., 0.],
         [0., 1., 0., 0., 0.],
         [0., 1., 0., 0., 0.]],

        [[0., 0., 1., 0., 0.],
         [0., 0., 1., 0., 0.],
         [0., 0., 1., 0., 0.]],

        [[0., 0., 0., 1., 0.],
         [0., 0., 0., 1., 0.],
         [0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.]]])

In [50]: d.enumerate_support().shape
Out[50]: torch.Size([5, 3, 5])

In [51]: d.enumerate_support(expand=False)
Out[51]:
tensor([[[1., 0., 0., 0., 0.]],

        [[0., 1., 0., 0., 0.]],

        [[0., 0., 1., 0., 0.]],

        [[0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 1.]]])

In [52]: d.enumerate_support(expand=False).shape
Out[52]: torch.Size([5, 1, 5])
```

**Motivation:**
 - Currently `enumerate_support` builds up tensors of size `support + batch_shape + event_shape`, but the values are *repeated* over the `batch_shape` (adding little in the way of information). This can lead to expensive matrix operations over large tensors when `batch_shape` is large (see, example above), often leading to OOM issues. We use `expand=False` in Pyro for message passing inference. e.g. when enumerating over the state space in a Hidden Markov Model. This creates sparse tensors that capture the markov dependence, and allows for the possibility of using optimized matrix operations over these sparse tensors. `expand=True`, on the other hand, will create tensors that scale exponentially in size with the length of the Markov chain.
 - We have been using this in our [patch](https://github.com/uber/pyro/blob/dev/pyro/distributions/torch.py) of `torch.distributions` in Pyro. The interface has been stable, and it is already being used in a few Pyro algorithms. We think that this is more broadly applicable and will be of interest to the larger distributions community.

cc. apaszke, fritzo, alicanb
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11231

Differential Revision: D9696290

Pulled By: soumith

fbshipit-source-id: c556f8ff374092e8366897ebe3f3b349538d9318
2018-09-06 21:39:42 -07:00
Neeraj Pradhan
434e943b08 Fix to distribution.__repr__ with lazy attributes (#11263)
Summary:
`__repr__` currently fails for distributions with lazy attributes in PyTorch master, throwing a `KeyError`. This fixes the issue.

**Additionally:**
 - Added `logits` to `arg_constraints` for distributions that accept either `probs` or `logits`. This is both to have `__repr__` display the `logits` param when available, and to be able to do validation checks (e.g. NaN checks) when the logit parametrization is used. fritzo, alicanb - I think there were reasons why we had not done so in the first place, but I am unable to recall now. It passes all the tests, but let me know if there is something that I am missing at the moment.
 - There are certain distributions, e.g. `OneHotCategorical` which won't show any parameters because it uses a `categorical` instance under the hood and neither `logits` / `probs` in `arg_constraints` are present in the instance's `__dict__`. This isn't addressed in this PR.

cc. vishwakftw, fritzo, nadavbh12, apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11263

Differential Revision: D9654959

Pulled By: apaszke

fbshipit-source-id: 16f5b20243fe8e2c13e9c528050d4df0b8ea6e45
2018-09-05 09:55:51 -07:00
nadavbh12
8a1739b05d Add arguments __repr__ in Distribution base class
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/10373

Differential Revision: D9240316

Pulled By: ezyang

fbshipit-source-id: f35c500f61f86e6be405e8bd4040db5146224984
2018-08-21 12:10:23 -07:00
Fritz Obermeyer
187955b959 [distributions] Skip validation of lazy properties (#6666) 2018-04-18 10:12:08 +02:00
Tongzhou Wang
1c01eabd3c
Codemod to update our codebase to 0.4 standard (#6641)
* Codemod to update our codebase to 0.4 standard

* Update some of the test scripts

* remove Variable in test_clip_grad_value

* fix _symbolic_override_wrapper_maker
2018-04-17 22:06:54 -04:00
Xingdong Zuo
65a8ac0b8e Add method to calculate perplexity of distribution (#6427) 2018-04-10 12:18:26 +02:00
Fritz Obermeyer
b2da9fd220 [distributions] Rename .params to .arg_constraints, fix logic (#5989) 2018-03-25 15:24:32 +02:00
lazypanda1
7f864bbe52 Fixed distribution constraints and added some test cases for distributions parameter check (#5358) 2018-03-15 23:11:20 +01:00
Sam Gross
54b4cdeffa
Replace all uses of 'Tensor or Variable' with 'Tensor' (#5508)
Replace all uses of 'Tensor or Variable'  and 'Variable or Tensor' with 'Tensor'
2018-03-02 14:26:11 -05:00
Sam Gross
70ba50c3d4 Remove some uses of torch.is_tensor in favor of isinstance (#5473) 2018-03-02 06:17:38 -05:00
gchanan
d5038309a1
Remove WITH_SCALARS, as it's enabled by default now. (#5437) 2018-02-27 14:51:11 -05:00
Fritz Obermeyer
a4d0a74cee Ensure Distribution.sample() result is detached (#5086) 2018-02-14 01:32:11 +01:00
Vishwak Srinivasan
011941087a Implementation of the cumulative distribution function and its inverse (#5079) 2018-02-07 16:10:19 +01:00
Alican Bozkurt
20fbdb9a8b Adding mean, variance, stddev to distributions (#4923) 2018-01-31 00:26:32 +01:00
gchanan
4970e73304
Add support for distributions and test_distributions when WITH_SCALAR… (#4834)
* Add support for distributions and test_distributions when WITH_SCALARS enabled.

* Fix flake8.
2018-01-24 19:22:05 -05:00
Vishwak Srinivasan
86fe793948 Addition of KL-Divergences for torch.distributions (#4638) 2018-01-14 22:52:28 +01:00
Fritz Obermeyer
a3e91515de Declare constraints for distribution parameters and support (#4450) 2018-01-04 23:58:26 +01:00
Fritz Obermeyer
5c33400dd3 Implement OneHotCategorical distribution (#4357) 2017-12-28 16:54:55 +01:00
Neeraj Pradhan
0c4b3f4271 Adding Uniform distribution to PyTorch (#4328) 2017-12-23 15:14:44 +01:00
Fritz Obermeyer
0bc1505f34 Implement .entropy() methods for all distributions (#4268) 2017-12-20 14:06:01 +01:00
Alican Bozkurt
94ff31f54d Implement Exponential distribution (#4234)
* add exponential distribution

* add exponential tests

* fix default val of sample_shape

* lambd->rate

* updates per review

* remove notes, keep failure_rate same in exponential test
2017-12-18 16:44:35 -05:00
Fritz Obermeyer
bcbb36e99a Allow value broadcasting in distributions.Distribution (#4210) 2017-12-18 20:11:39 +01:00