Commit Graph

58 Commits

Author SHA1 Message Date
Maggie Moss
154e4d36e9 Fix pyrefly ignore syntax in distributions and ao (#166248)
Ensures existing pyrefly ignores only ignore the intended error code

pyrefly check
lintrunner

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166248
Approved by: https://github.com/oulgen
2025-10-26 22:13:48 +00:00
Maggie Moss
b13cd141b3 Add pyrefly suppressions (#164748)
Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete entries from the `project-excludes` field in the pyrefly.toml file
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:

0 errors (4,263 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164748
Approved by: https://github.com/oulgen
2025-10-07 17:31:18 +00:00
Randolf Scholz
6c38b9be73 [typing] Add type hints to __init__ methods in torch.distributions. (#144197)
Fixes #144196
Extends #144106 and #144110

## Open Problems:

- [ ] Annotating with `numbers.Number` is a bad idea; consider using `float`, `SupportsFloat`, or some `Protocol` instead. https://github.com/pytorch/pytorch/pull/144197#discussion_r1903324769

## Notes

- `beta.py`: needed to add `type: ignore` since `broadcast_all` is untyped.
- `categorical.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- ~~`dirichlet.py`: replaced `axis` with `dim` arguments.~~ #144402
- `geometric.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- ~~`independent.py`: fixed bug in `Independent.__init__` where `tuple[int, ...]` could be passed to `Distribution.__init__` instead of `torch.Size`.~~ **EDIT:** turns out the bug is related to typing of `torch.Size`. #144218
- `independent.py`: made `Independent` a generic class of its base distribution.
- `multivariate_normal.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- `relaxed_bernoulli.py`: added class-level type hint for `base_dist`.
- `relaxed_categorical.py`: added class-level type hint for `base_dist`.
- ~~`transforms.py`: Added missing argument to docstring of `ReshapeTransform`~~ #144401
- ~~`transforms.py`: Fixed bug in `AffineTransform.sign` (could return `Tensor` instead of `int`).~~ #144400
- `transforms.py`: Added `type: ignore` comments to `AffineTransform.log_abs_det_jacobian`[^1]; replaced `torch.abs(scale)` with `scale.abs()`.
- `transforms.py`: Added `type: ignore` comments to `AffineTransform.__eq__`[^1].
- `transforms.py`: Fixed type hint on `CumulativeDistributionTransform.domain`. Note that this is still an LSP violation, because `Transform.domain` is defined as `Constraint`, but `Distribution.domain` is defined as `Optional[Constraint]`.
- skipped: `constraints.py`, `constraints_registry.py`, `kl.py`, `utils.py`, `exp_family.py`, `__init__.py`.

## Remark

`TransformedDistribution`: `__init__` uses the check `if reinterpreted_batch_ndims > 0:`, which can lead to the creation of `Independent` distributions with only 1 component. This results in awkward code like `base_dist.base_dist` in `LogisticNormal`.

```python
import torch
from torch.distributions import *
b1 = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
b2 = MultivariateNormal(torch.tensor([0.0]), torch.eye(1))
t = StickBreakingTransform()
d1 = TransformedDistribution(b1, t)
d2 = TransformedDistribution(b2, t)
print(d1.base_dist)  # Independent with 1 dimension
print(d2.base_dist)  # MultivariateNormal
```

One could consider changing this to `if reinterpreted_batch_ndims > 1:`.

[^1]: Usage of `isinstance(value, numbers.Real)` leads to problems with static typing, as the `numbers` module is not supported by `mypy` (see <https://github.com/python/mypy/issues/3186>). This results in us having to add type-ignore comments in several places
[^2]: Otherwise, we would have to add a bunch of `type: ignore` comments to make `mypy` happy, as it isn't able to perform the type narrowing. Ideally, such code should be replaced with structural pattern matching once support for Python 3.9 is dropped.
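A hedged sketch of the `else` → `if` conversion described in note [^2]; the function below is a simplified stand-in, not the actual `torch.distributions` code:

```python
from typing import Optional

from torch import Tensor


def normalize_params(probs: Optional[Tensor] = None, logits: Optional[Tensor] = None) -> Tensor:
    # The two arguments are mutually exclusive, but mypy cannot narrow
    # `logits` to `Tensor` inside a bare `else:` branch, so the branch is
    # written as an explicit `elif` test instead.
    if (probs is None) == (logits is None):
        raise ValueError("Either `probs` or `logits` must be specified, but not both.")
    if probs is not None:
        return probs / probs.sum(-1, keepdim=True)
    elif logits is not None:  # was `else:`; the explicit check lets mypy narrow the type
        return logits - logits.logsumexp(dim=-1, keepdim=True)
    raise AssertionError("unreachable")
```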

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144197
Approved by: https://github.com/malfet

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
2025-04-06 17:50:35 +00:00
Xuehai Pan
995df34b19 [BE][PYFMT] migrate PYFMT for torch.{distributed,distributions} to ruff format (#144547)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144547
Approved by: https://github.com/kwen2501
2025-02-28 07:35:56 +00:00
Randolf Scholz
355b0bc7e3 [typing] Add type hints to @property and @lazy_property in torch.distributions. (#144110)
Fixes #76772, #144196
Extends #144106

- added type annotations to `lazy_property`.
- added type annotations to all `@property` and `@lazy_property` members inside the `torch.distributions` module (see the sketch after this list).
- added a simple type-check unit test to ensure type inference is working.
- replaced deprecated annotations like `typing.List` with their builtin counterparts.
- simplified `torch.Tensor` hints to plain `Tensor`; otherwise signatures become very verbose.
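
As a small usage sketch (not part of the PR itself), the annotations let static checkers infer `Tensor` for these attributes instead of `Any`:

```python
import torch
from torch.distributions import Normal

d = Normal(torch.zeros(3), torch.ones(3))

# `d.mean` / `d.stddev` are @property / @lazy_property attributes; with the
# added hints a type checker now infers `Tensor` for them.
mean: torch.Tensor = d.mean
stddev: torch.Tensor = d.stddev
print(mean.shape, stddev.shape)  # torch.Size([3]) torch.Size([3])
```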

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144110
Approved by: https://github.com/Skylion007
2025-01-07 19:27:36 +00:00
Xuehai Pan
b25ef91bf1 [BE][Easy][18/19] enforce style for empty lines in import segments in torch/d*/ (#129770)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129770
Approved by: https://github.com/wconstab
2024-08-01 04:22:50 +00:00
Aaron Orenstein
7c12cc7ce4 Flip default value for mypy disallow_untyped_defs [6/11] (#127843)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127843
Approved by: https://github.com/oulgen
ghstack dependencies: #127842
2024-06-08 18:49:29 +00:00
a-r-r-o-w
e08577aec5 Spelling fix (#108490)
Fixes spelling mistake: non-deterinistic -> non-deterministic
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108490
Approved by: https://github.com/ezyang
2023-09-04 16:59:35 +00:00
Edward Z. Yang
3bf922a6ce Apply UFMT to low traffic torch modules (#106249)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106249
Approved by: https://github.com/Skylion007
2023-07-29 23:37:30 +00:00
Xuehai Pan
b005ec62b9 [BE] Remove dependency on six and future (#94709)
Remove the Python 2 and 3 compatibility library [six](https://pypi.org/project/six) and [future](https://pypi.org/project/future) and `torch._six`. We only support Python 3.8+ now. It's time to retire them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94709
Approved by: https://github.com/malfet, https://github.com/Skylion007
2023-02-14 09:14:14 +00:00
Xuehai Pan
5b1cedacde [BE] [2/3] Rewrite super() calls in functorch and torch (#94588)
Rewrite Python built-in class `super()` calls. Only non-semantic changes should be applied.

- #94587
- #94588
- #94592

Also, methods with only a `super()` call are removed:

```diff
class MyModule(nn.Module):
-   def __init__(self):
-       super().__init__()
-
    def forward(self, ...):
        ...
```

Some cases that change the semantics should be kept unchanged. E.g.:

f152a79be9/caffe2/python/net_printer.py (L184-L190)

f152a79be9/test/test_jit_fuser_te.py (L2628-L2635)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94588
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-10 21:16:33 +00:00
joncrall
4618371da5 Integrate xdoctest - Rebased (#82797)
This is a new version of #15648 based on the latest master branch.

Unlike the previous PR where I fixed a lot of the doctests in addition to integrating xdoctest, I'm going to reduce the scope here. I'm simply going to integrate xdoctest, and then I'm going to mark all of the failing tests as "SKIP". This will let xdoctest run on the dashboards, provide some value, and still let the dashboards pass. I'll leave fixing the doctests themselves to another PR.

In my initial commit, I do the bare minimum to get something running with failing dashboards. The few tests that I marked as skip are causing segfaults. Running xdoctest results in 293 failed, 201 passed tests. The next commits will be to disable those tests. (unfortunately I don't have a tool that will insert the `#xdoctest: +SKIP` directive over every failing test, so I'm going to do this mostly manually.)
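
For illustration, a skipped doctest would look roughly like this (the function is a hypothetical stand-in):

```python
import torch


def scale_by_two(x):
    """
    Example:
        >>> # xdoctest: +SKIP
        >>> scale_by_two(torch.randn(3)).shape  # currently failing; skipped for now
        torch.Size([3])
    """
    return x * 2
```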

Fixes https://github.com/pytorch/pytorch/issues/71105

@ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82797
Approved by: https://github.com/ezyang
2022-08-12 02:08:01 +00:00
anjali411
3bcc19b29a Add __all__ to various submodules in torch.fx, distributions, distributed, package (#80367)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80367
Approved by: https://github.com/albanD
2022-06-27 21:27:30 +00:00
Till Hoffmann
40576bceaf Add mode property to distributions. (#76690)
This PR fixes #69466 and introduces some other minor changes. Tests are somewhat more involved because a reference implementation in `scipy` is not available; tests proceed differently for discrete and continuous distributions.

For continuous distributions, we evaluate the gradient of the `log_prob` at the mode. Tests pass if the gradient is zero OR (the mode is at the boundary of the support of the distribution AND the `log_prob` decreases as we move away from the boundary to the interior of the support).

For discrete distributions, the notion of a gradient is not well defined. We thus "look" ahead and behind one step (e.g. if the mode of a Poisson distribution is 9, we consider 8 and 10). If the step ahead/behind is still within the support of the distribution, we assert that the `log_prob` is smaller than at the mode.

For one-hot encoded distributions (currently just `OneHotCategorical`), we evaluate the underlying mode (i.e. encoded as an integral tensor), "advance" by one label to get another sample that should have lower probability using `other = (mode + 1) % event_size` and re-encode as one-hot. The resultant `other` sample should have lower probability than the mode.

Furthermore, Gamma, half Cauchy, and half normal distributions have their support changed from positive to nonnegative. This change is necessary because the mode of the "half" distributions is zero, and the mode of the gamma distribution is zero for `concentration <= 1`.
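
A minimal sketch of the testing idea, using `Normal` and `Poisson` as stand-ins (this is not the actual test code):

```python
import torch
from torch.distributions import Normal, Poisson

# Continuous case: the gradient of log_prob should vanish at the mode
# (unless the mode lies on the boundary of the support).
d = Normal(torch.tensor(2.0), torch.tensor(3.0))
mode = d.mode.detach().requires_grad_()
d.log_prob(mode).backward()
assert torch.allclose(mode.grad, torch.zeros_like(mode.grad), atol=1e-6)

# Discrete case: stepping one unit away from the mode (while staying inside
# the support) must not increase log_prob.
p = Poisson(torch.tensor(4.5))
m = p.mode
assert p.log_prob(m) >= p.log_prob(m + 1)
assert m == 0 or p.log_prob(m) >= p.log_prob(m - 1)
```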

cc @fritzo

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76690
Approved by: https://github.com/neerajprad
2022-05-11 18:26:56 +00:00
lezcano
4e347f1242 [docs] Fix backticks in docs (#60474)
Summary:
There is a very common error when writing docs: One forgets to write a matching `` ` ``, and something like ``:attr:`x`` is rendered in the docs. This PR fixes most (all?) of these errors (and a few others).

I found these running ``grep -r ">[^#<][^<]*\`"`` on the `docs/build/html/generated` folder. The regex finds an HTML tag that does not start with `#` (as python comments in example code may contain backticks) and that contains a backtick in the rendered HTML.

This regex has not given any false positives in the current codebase, so I am inclined to suggest that we should add this check to the CI. Would this be possible / reasonable / easy to do, malfet?

Pull Request resolved: https://github.com/pytorch/pytorch/pull/60474

Reviewed By: mrshenli

Differential Revision: D29309633

Pulled By: albanD

fbshipit-source-id: 9621e0e9f87590cea060dd084fa367442b6bd046
2021-06-24 06:27:41 -07:00
neerajprad
e2041ce354 Fix docstring to clarify logits usage for multiclass case (#51053)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/50378.

Additionally, this has some minor fixes:
 - [x] Fix mean for half-cauchy to return `inf` instead of `nan`.
 - [x] Fix constraints/support for the relaxed categorical distribution.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/51053

Reviewed By: heitorschueroff

Differential Revision: D26077966

Pulled By: neerajprad

fbshipit-source-id: ca0213baa9bbdbc661aebbb901ab5e7fded38a5f
2021-01-26 17:01:39 -08:00
Fritz Obermeyer
a347c747df Fix TransformedDistribution shaping logic (#50581)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/50496
Fixes https://github.com/pytorch/pytorch/issues/34859
Fixes https://github.com/pytorch/pytorch/issues/21596

This fixes many bugs involving `TransformedDistribution` and `ComposeTransform` when the component transforms changed their event shapes. Part of the fix is to introduce an `IndependentTransform` analogous to `distributions.Independent` and `constraints.independent`, and to introduce methods `Transform.forward_shape()` and `.inverse_shape()`. I have followed fehiepsi's suggestion and replaced `.input_event_dim` -> `.domain.event_dim` and `.output_event_dim` -> `.codomain.event_dim`. This allows us to deprecate `.event_dim` as an attribute.

## Summary of changes

- Fixes `TransformedDistribution` and `ComposeTransform` shape errors.
- Fixes a behavior bug in `LogisticNormal`.
- Fixes `kl_divergence(TransformedDistribution, TransformedDistribution)`
- Adds methods `Transform.forward_shape()`, `.inverse_shape()` which are required for correct shape computations in `TransformedDistribution` and `ComposeTransform`.
- Adds an `IndependentTransform`.
- Adds a `ReshapeTransform` which is invaluable in testing shape logic in `ComposeTransform` and `TransformedDistribution` and which will be used by stefanwebb's flowtorch (see the sketch after this list).
- Fixes incorrect default values in `constraints.dependent.event_dim`.
- Documents the `.event_dim` and `.is_discrete` attributes.
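
A minimal sketch (under the assumption that the APIs behave as described above) of how these pieces fit together:

```python
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import ExpTransform, IndependentTransform, ReshapeTransform

# IndependentTransform reinterprets batch dims of an elementwise transform as
# event dims, mirroring distributions.Independent / constraints.independent.
t = IndependentTransform(ExpTransform(), reinterpreted_batch_ndims=1)
base = Normal(torch.zeros(5, 3), torch.ones(5, 3))
d = TransformedDistribution(base, t)
print(d.batch_shape, d.event_shape)  # torch.Size([5]) torch.Size([3])

# forward_shape()/inverse_shape() report how a transform maps shapes, which is
# what TransformedDistribution and ComposeTransform need for their shape logic.
r = ReshapeTransform(in_shape=(3,), out_shape=(3, 1))
print(r.forward_shape((5, 3)))     # (5, 3, 1)
print(r.inverse_shape((5, 3, 1)))  # (5, 3)
```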

## Changes planned for follow-up PRs

- Memoize `constraints.dependent_property` as we do with `lazy_property`, since we now consult those properties much more often.

## Tested
- [x] added a test for `Dist.support` vs `Dist(**params).support` to ensure static and dynamic attributes agree.
- [x] refactoring is covered by existing tests
- [x] add test cases for `ReshapeTransform`
- [x] add a test for `TransformedDistribution` on a wide grid of input shapes
- [x] added a regression test for https://github.com/pytorch/pytorch/issues/34859

cc fehiepsi feynmanliang stefanwebb

Pull Request resolved: https://github.com/pytorch/pytorch/pull/50581

Reviewed By: ezyang, glaringlee, jpchen

Differential Revision: D26024247

Pulled By: neerajprad

fbshipit-source-id: f0b9a296f780ff49659b132409e11a29985dde9b
2021-01-25 16:34:12 -08:00
Fritz Obermeyer
21c2542b6a Independent constraint (#50547)
Summary:
Addresses https://github.com/pytorch/pytorch/issues/50496

This fixes a number of inconsistencies in torch.distributions.constraints as used for parameters and supports of probability distributions.
- Adds `constraints.independent` and replaces `real_vector` with `independent(real, 1)` (this pattern has long been used in Pyro; see the sketch after this list).
- Adds an `.event_dim` attribute to all constraints.
- Tests that `constraint.check(data)` has the correct shape. (Previously the shapes were incorrect).
- Adds machinery to set static `.is_discrete` and `.event_dim` for `constraints.dependent`.
- Fixes constraints for a number of distributions.
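
A small sketch of the new pieces (assuming the names exported from `torch.distributions.constraints`):

```python
import torch
from torch.distributions import constraints

# independent(real, 1) treats the rightmost dimension as an event dimension,
# so check() returns one boolean per batch element rather than per scalar;
# real_vector is now defined exactly this way.
c = constraints.independent(constraints.real, 1)
print(c.event_dim)  # 1

x = torch.randn(4, 3)
print(c.check(x).shape)  # torch.Size([4])
```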

## Tested
- added a new check to the constraints tests
- added a new check for `.event_dim`

cc fehiepsi feynmanliang stefanwebb

Pull Request resolved: https://github.com/pytorch/pytorch/pull/50547

Reviewed By: VitalyFedyunin

Differential Revision: D25918330

Pulled By: neerajprad

fbshipit-source-id: a648c3de3e8704f70f445c0f1c39f2593c8c74db
2021-01-21 18:42:45 -08:00
Eric Cotner
e4efc420ae Correct Categorical docstring (#45804)
Summary:
Clarified that the `Categorical` distribution will actually accept input of any arbitrary tensor shape, not just 1D and 2D tensors.
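
A small illustration of the clarified behavior (not taken from the docstring itself):

```python
import torch
from torch.distributions import Categorical

# `probs` may carry any number of leading batch dimensions; only the last
# dimension indexes the categories.
probs = torch.softmax(torch.randn(2, 3, 4, 5), dim=-1)
d = Categorical(probs=probs)
print(d.batch_shape)     # torch.Size([2, 3, 4])
print(d.sample().shape)  # torch.Size([2, 3, 4])
```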

Pull Request resolved: https://github.com/pytorch/pytorch/pull/45804

Reviewed By: dzhulgakov

Differential Revision: D24125415

Pulled By: VitalyFedyunin

fbshipit-source-id: 5fa1f07911bd85e172199b28d79763428db3a0f4
2020-10-05 21:49:10 -07:00
mattip
a0f110190c clamp Categorical logit from -inf to min_fifo when calculating entropy (#41002)
Summary:
Fixes gh-40553 by clamping logit values when calculating Categorical.entropy

Pull Request resolved: https://github.com/pytorch/pytorch/pull/41002

Reviewed By: mruberry

Differential Revision: D22436432

Pulled By: ngimel

fbshipit-source-id: 08b7c7b0c15ab4e5a56b3a8ec0d0237ad360202e
2020-07-14 16:21:12 -07:00
Neeraj Pradhan
e43c2d59dd Reduce memory overhead of categorical.sample (#34900)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/34714 (using the discussed solution). Thanks to jjabo for flagging and suggesting this.

Instead of expanding `probs` to prepend `sample_shape`, it is better to use the `num_samples` argument to `torch.multinomial`, which is faster and consumes less memory.

Existing tests should cover this. I have profiled this on different inputs and the change results in faster `.sample` (e.g. 100X faster on the example in the issue), or at worst is similar to what we have now with the default `sample_shape` argument.
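
A rough sketch of the idea (standalone, not the actual `Categorical.sample` code):

```python
import torch

probs = torch.rand(1000, 10)
probs = probs / probs.sum(-1, keepdim=True)

# Draw all 100 samples per batch row in a single multinomial call instead of
# expanding `probs` to (100, 1000, 10) first; memory stays proportional to
# the size of `probs`.
samples = torch.multinomial(probs, num_samples=100, replacement=True).T
print(samples.shape)  # torch.Size([100, 1000]), same as Categorical(probs).sample((100,))
```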

cc. fritzo, alicanb, ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34900

Differential Revision: D20499065

Pulled By: ngimel

fbshipit-source-id: e5be225e3e219bd268f5f635aaa9bf7eca39f09c
2020-03-17 17:49:41 -07:00
Fritz Obermeyer
06762b4721 Fix distributions.Categorical.sample bug from .view() (#23328)
Summary:
This modernizes distributions code by replacing a few uses of `.contiguous().view()` with `.reshape()`, fixing a sample bug in the `Categorical` distribution.

The bug is exercised by the following test:
```py
batch_shape = (1, 2, 1, 3, 1)
sample_shape = (4,)
cardinality = 2
logits = torch.randn(batch_shape + (cardinality,))
dist.Categorical(logits=logits).sample(sample_shape)
# RuntimeError: invalid argument 2: view size is not compatible with
#   input tensor's size and stride (at least one dimension spans across
#   two contiguous subspaces). Call .contiguous() before .view().
#   at ../aten/src/TH/generic/THTensor.cpp:203
```
I have verified this works locally, but I have not added this as a regression test because it is unlikely to regress (the code is now simpler).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23328

Differential Revision: D16510678

Pulled By: colesbury

fbshipit-source-id: c125c1a37d21d185132e8e8b65241c86ad8ad04b
2019-07-29 12:09:50 -07:00
James Malcolm
42770e1370 Improving Categorical Distribution Docs' (#16291) (#21707)
Summary:
**Closes:** Confusing documentation about logits in distributions.Categorical: https://github.com/pytorch/pytorch/issues/16291

**Solution**: Changes the documentation of the Categorical distribution from `log probabilities` to `event log-odds`. This should reduce the confusion raised in the issue, and is consistent with other distributions such as `torch.Binomial`.

More than happy to make any other changes if they fit :).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21707

Differential Revision: D15799181

Pulled By: soumith

fbshipit-source-id: f11acca7a5c130102a3ff6674640235ee5aa69bf
2019-06-12 23:54:02 -07:00
Neeraj Pradhan
fb40e58f24 Remove deprecated tensor constructors in torch.distributions (#19979)
Summary:
This removes the deprecated `tensor.new_*` constructors (see #16770) from `torch.distributions` module.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19979

Differential Revision: D15195618

Pulled By: soumith

fbshipit-source-id: 46b519bfd32017265e90bd5c53f12cfe4a138021
2019-05-02 20:45:02 -07:00
Ahmad Salim Al-Sibahi
8e1e29124d Fix precision issue with expansion that prefers 'probs' over 'logits' (#18614)
Summary:
I have observed that sometimes both were in `__dict__`, but expansion chose to copy `probs`, which loses precision compared to `logits`. This is especially important when training (Bayesian) neural networks or doing other types of optimization, since the loss is heavily affected.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18614

Differential Revision: D14793486

Pulled By: ezyang

fbshipit-source-id: d4ff5e34fbb4021ea9de9f58af09a7de00d80a63
2019-04-05 13:07:01 -07:00
Edward Yang
173f224570 Turn on F401: Unused import warning. (#18598)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18598
ghimport-source-id: c74597e5e7437e94a43c163cee0639b20d0d0c6a

Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#18598 Turn on F401: Unused import warning.**

This was requested by someone at Facebook; this lint is turned
on for Facebook by default.  "Sure, why not."

I had to noqa a number of imports in __init__.  Hypothetically
we're supposed to use __all__ in this case, but I was too lazy
to fix it.  Left for future work.

Be careful!  flake8-2 and flake8-3 behave differently with
respect to import resolution for # type: comments.  flake8-3 will
report an import unused; flake8-2 will not.  For now, I just
noqa'd all these sites.

All the changes were done by hand.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Differential Revision: D14687478

fbshipit-source-id: 30d532381e914091aadfa0d2a5a89404819663e3
2019-03-30 09:01:17 -07:00
Neeraj Pradhan
cda71e2600 Disallow scalar parameters in Dirichlet and Categorical (#11589)
Summary:
This adds a small check in `Dirichlet` and `Categorical` `__init__` methods to ensure that scalar parameters are not admissible.

**Motivation**
Currently, `Dirichlet` throws no error when provided with a scalar parameter, but if we `expand` a scalar instance, it inherits the empty event shape from the original instance and gives unexpected results.

The alternative to this check is to promote `event_shape` to be `torch.Size((1,))` if the original instance was a scalar, but that seems to add a bit more complexity (and changes the behavior of `expand` in that it would affect the `event_shape` as well as the `batch_shape` now). Does this seem reasonable? cc. alicanb, fritzo.

```python
In [4]: d = dist.Dirichlet(torch.tensor(1.))

In [5]: d.sample()
Out[5]: tensor(1.0000)

In [6]: d.log_prob(d.sample())
Out[6]: tensor(0.)

In [7]: e = d.expand([3])

In [8]: e.sample()
Out[8]: tensor([0.3953, 0.1797, 0.4250])  # interpreted as events

In [9]: e.log_prob(e.sample())
Out[9]: tensor(0.6931)  # wrongly summed out

In [10]: e.batch_shape
Out[10]: torch.Size([3])

In [11]: e.event_shape
Out[11]: torch.Size([])  # cannot be empty
```

Additionally, based on review comments, this removes `real_vector` constraint. This was only being used in `MultivariateNormal`, but I am happy to revert this if we want to keep it around for backwards compatibility.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11589

Differential Revision: D9818271

Pulled By: soumith

fbshipit-source-id: f9bbba90ed6f04e0b5bdfa169e70ca20b280fc74
2018-09-14 07:55:35 -07:00
Fritz Obermeyer
bbf54ea37c Ensure .enumerate_support() methods are jittable (#11542)
Summary:
This works around #11535 by avoiding `arange(n, out=x)` and `eye(n, out=x)` in `torch.distributions`. I've confirmed that the `.enumerate_support()` methods are now jittable.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11542

Differential Revision: D9777805

Pulled By: apaszke

fbshipit-source-id: fa38f2f1acfc0a289f725fd8c92478573cfdbefb
2018-09-11 18:26:09 -07:00
Neeraj Pradhan
80fa8e1007 Add .expand() method to distribution classes (#11341)
Summary:
This adds a `.expand` method for distributions that is akin to the `torch.Tensor.expand` method for tensors. It returns a new distribution instance with batch dimensions expanded to the desired `batch_shape`. Since this calls `torch.Tensor.expand` on the distribution's parameters, it does not allocate new memory for the expanded distribution instance's parameters.

e.g.
```python
>>> d = dist.Normal(torch.zeros(100, 1), torch.ones(100, 1))
>>> d.sample().shape
  torch.Size([100, 1])
>>> d.expand([100, 10]).sample().shape
  torch.Size([100, 10])
```

We have already been using the `.expand` method in Pyro in our [patch](https://github.com/uber/pyro/blob/dev/pyro/distributions/torch.py#L10) of `torch.distributions`. We use this in our models to enable dynamic broadcasting. This has also been requested by a few users on the distributions slack, and we believe will be useful to the larger community.

Note that currently, there is no convenient and efficient way to expand distribution instances:
 - Many distributions use `TransformedDistribution` (or wrap another distribution instance, e.g. `OneHotCategorical` uses a `Categorical` instance) under the hood, or have lazy parameters. This makes it difficult to collect all the relevant parameters, broadcast them and construct new instances.
 - In the few cases where this is even possible, the resulting implementation would be inefficient since we will go through a lot of broadcasting and args validation logic in `__init__.py` that can be avoided.

The `.expand` method allows for a safe and efficient way to expand distribution instances. Additionally, this bypasses `__init__.py` (using `__new__` and populating relevant attributes) since we do not need to do any broadcasting or args validation (which was already done when the instance was first created). This can result in significant savings as compared to constructing new instances via `__init__` (that said, the `sample` and `log_prob` methods will probably be the rate determining steps in many applications).

e.g.
```python
>>> a = dist.Bernoulli(torch.ones([10000, 1]), validate_args=True)

>>> %timeit a.expand([10000, 100])
15.2 µs ± 224 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

>>> %timeit dist.Bernoulli(torch.ones([10000, 100]), validate_args=True)
11.8 ms ± 153 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```

cc. fritzo, apaszke, vishwakftw, alicanb
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11341

Differential Revision: D9728485

Pulled By: soumith

fbshipit-source-id: 3b94c23bc6a43ee704389e6287aa83d1e278d52f
2018-09-11 06:56:18 -07:00
Neeraj Pradhan
b3b1e7624d Optional expand=True kwarg in distribution.enumerate_support (#11231)
Summary:
This adds an optional `expand=True` kwarg to the `distribution.enumerate_support()` method, to get a distribution's support without expanding the values over the distribution's `batch_shape`.
 - The default `expand=True` preserves the current behavior, whereas `expand=False` collapses the batch dimensions.

e.g.
```python
In [47]: d = dist.OneHotCategorical(torch.ones(3, 5) * 0.5)

In [48]: d.batch_shape
Out[48]: torch.Size([3])

In [49]: d.enumerate_support()
Out[49]:
tensor([[[1., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0.]],

        [[0., 1., 0., 0., 0.],
         [0., 1., 0., 0., 0.],
         [0., 1., 0., 0., 0.]],

        [[0., 0., 1., 0., 0.],
         [0., 0., 1., 0., 0.],
         [0., 0., 1., 0., 0.]],

        [[0., 0., 0., 1., 0.],
         [0., 0., 0., 1., 0.],
         [0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.]]])

In [50]: d.enumerate_support().shape
Out[50]: torch.Size([5, 3, 5])

In [51]: d.enumerate_support(expand=False)
Out[51]:
tensor([[[1., 0., 0., 0., 0.]],

        [[0., 1., 0., 0., 0.]],

        [[0., 0., 1., 0., 0.]],

        [[0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 1.]]])

In [52]: d.enumerate_support(expand=False).shape
Out[52]: torch.Size([5, 1, 5])
```

**Motivation:**
 - Currently `enumerate_support` builds up tensors of size `support + batch_shape + event_shape`, but the values are *repeated* over the `batch_shape` (adding little in the way of information). This can lead to expensive matrix operations over large tensors when `batch_shape` is large (see, example above), often leading to OOM issues. We use `expand=False` in Pyro for message passing inference. e.g. when enumerating over the state space in a Hidden Markov Model. This creates sparse tensors that capture the markov dependence, and allows for the possibility of using optimized matrix operations over these sparse tensors. `expand=True`, on the other hand, will create tensors that scale exponentially in size with the length of the Markov chain.
 - We have been using this in our [patch](https://github.com/uber/pyro/blob/dev/pyro/distributions/torch.py) of `torch.distributions` in Pyro. The interface has been stable, and it is already being used in a few Pyro algorithms. We think that this is more broadly applicable and will be of interest to the larger distributions community.

cc. apaszke, fritzo, alicanb
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11231

Differential Revision: D9696290

Pulled By: soumith

fbshipit-source-id: c556f8ff374092e8366897ebe3f3b349538d9318
2018-09-06 21:39:42 -07:00
Neeraj Pradhan
434e943b08 Fix to distribution.__repr__ with lazy attributes (#11263)
Summary:
`__repr__` currently fails for distributions with lazy attributes in PyTorch master, throwing a `KeyError`. This fixes the issue.

**Additionally:**
 - Added `logits` to `arg_constraints` for distributions that accept either `probs` or `logits`. This is both to have `__repr__` display the `logits` param when available, and to be able to do validation checks (e.g. NaN checks) when the logit parametrization is used. fritzo, alicanb - I think there were reasons why we had not done so in the first place, but I am unable to recall now. It passes all the tests, but let me know if there is something that I am missing at the moment.
 - There are certain distributions, e.g. `OneHotCategorical`, which won't show any parameters because they use a `categorical` instance under the hood and neither `logits` nor `probs` from `arg_constraints` is present in the instance's `__dict__`. This isn't addressed in this PR.

cc. vishwakftw, fritzo, nadavbh12, apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11263

Differential Revision: D9654959

Pulled By: apaszke

fbshipit-source-id: 16f5b20243fe8e2c13e9c528050d4df0b8ea6e45
2018-09-05 09:55:51 -07:00
vishwakftw
f940af6293 Bag of Distributions doc fixes (#10894)
Summary:
- Added `__repr__` for Constraints and Transforms.
- Arguments passed to the constructor are now rendered with :attr:

Closes https://github.com/pytorch/pytorch/issues/10884
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10894

Differential Revision: D9514161

Pulled By: apaszke

fbshipit-source-id: 4abf60335d876449f2b6477eb9655afed9d5b80b
2018-08-27 09:55:27 -07:00
Fritz Obermeyer
fcfb1c1979 Make more distributions jittable
Summary:
This uses zou3519's new `torch.broadcast_tensors()` #10075 to make `Categorical.log_prob()` and the `*Normal.__init__()` methods jittable. Previously `.log_prob()` was failing due to calls to `torch._C._infer_size()` with errors like
```
    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
>       value_shape = torch._C._infer_size(value.size(), self.batch_shape) if self.batch_shape else value.size()
E       RuntimeError: expected int at position 0, but got: Tensor
```
After this change I'm able to jit many more of Pyro's tests.
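
A minimal sketch of the replacement pattern, using standalone tensors rather than the actual distributions code:

```python
import torch

value = torch.randn(4, 1)
loc = torch.zeros(1, 3)

# torch.broadcast_tensors() produces the broadcasted result directly, avoiding
# the torch._C._infer_size() call that fails under the JIT.
value, loc = torch.broadcast_tensors(value, loc)
print(value.shape, loc.shape)  # torch.Size([4, 3]) torch.Size([4, 3])
```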

Reviewed By: ezyang

Differential Revision: D9477487

Pulled By: apaszke

fbshipit-source-id: 5f39b29c6b8fa606ad30b02fefe2dfb618e883d6
2018-08-23 08:09:49 -07:00
Tongzhou Wang
27455e9c78 Use _six for inf and nan (#9500)
Summary:
Things like `float('inf')` are actually quite expensive.
```py
In [1]: import math

In [2]: %timeit -n 200 math.inf
49.3 ns ± 1.42 ns per loop (mean ± std. dev. of 7 runs, 200 loops each)

In [3]: %timeit -n 200 float('inf')
194 ns ± 39.1 ns per loop (mean ± std. dev. of 7 runs, 200 loops each)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9500

Reviewed By: soumith

Differential Revision: D8876229

Pulled By: SsnL

fbshipit-source-id: 78602b76bb53d5588910b58270930c0bd413d2d7
2018-07-18 10:40:29 -07:00
Alican Bozkurt
8766daeec9 Refactor _log_sum_exp (#9173)
Summary:
This PR removes `distributions.utils._log_sum_exp` in favor of `torch.logsumexp`. Also fixes some warnings with `reduce` arg. in `binary_cross_entropy_with_logits`
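
For reference, a tiny usage sketch of the replacement:

```python
import torch

x = torch.randn(4, 10)
# Numerically stable log(sum(exp(x))) along the last dim, replacing the
# removed private helper distributions.utils._log_sum_exp.
print(torch.logsumexp(x, dim=-1).shape)  # torch.Size([4])
```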
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9173

Reviewed By: SsnL

Differential Revision: D8764174

Pulled By: ezyang

fbshipit-source-id: b9c4136dbf0182e8ae77082e6448d23a430d5cb6
2018-07-15 17:40:53 -07:00
Thomas Viehmann
3799b10c44 various documentation formatting (#9359)
Summary:
This is a grab-bag of documentation formatting fixes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9359

Differential Revision: D8831400

Pulled By: soumith

fbshipit-source-id: 8dac02303168b2ea365e23938ee528d8e8c9f9b7
2018-07-13 02:48:25 -07:00
Will Feng
467fc3c436 [READY TO MERGE] Improve docs for Multinomial and Categorical distributions (#8472)
* Improve docs for Multinomial and Categorical distributions

* more improvement

* more improvement
2018-06-14 12:47:35 -04:00
Vishwak Srinivasan
61f61de270 Expose logsumexp docs and mark log_sum_exp in distributions for internal use (#8428) 2018-06-13 12:27:58 -04:00
Vishwak Srinivasan
3cbaa6b785 [ready] Clean up torch.distributions (#8046) 2018-06-02 16:54:53 +02:00
li-roy
d564ecb4a5 Update docs with new tensor repr (#6454)
* Update docs with new tensor repr

* remove cuda in dtype

* remove changes to gloo submodule

* [docs] document tensor.new_* ctor

* [docs] Add docs for tensor.to(), tensor.float(), etc

* [docs] Moar examples for docs.

* [docs] Warning for tensor ctor copy behavior

* Quick fix

* [docs] Document requires_grad_()

* [docs] Add example for requires_grad_()

* update slogdet and *fft

* update tensor rst

* small fixes

* update some docs

* additional doc changes

* update torch and tensor docs

* finish changing tensor docs

* fix flake8

* slogdet with negative det

* Update functional.py tensor ctors

* Fix nll_loss docs

* reorder to move device up

* torch.LongTensor -> torch.tensor or torch.empty in docs

* update tensor constructors in docs

* change tensor constructors

* change constructors

* change more Tensor() to tensor()

* Show requires_grads_ docs

* Fix set_default_dtype docs

* Link to torch.no_grad, etc, from torch doc

* Add dtype aliases to table

* regen docs again

* Tensor attributes stub page

* link to inplace sampling

* Link torch.dtype, device, and layout

* fix dots after nonfinite floats

* better layout docs
2018-04-21 07:35:37 -04:00
Tongzhou Wang
1c01eabd3c Codemod to update our codebase to 0.4 standard (#6641)
* Codemod to update our codebase to 0.4 standard

* Update some of the test scripts

* remove Variable in test_clip_grad_value

* fix _symbolic_override_wrapper_maker
2018-04-17 22:06:54 -04:00
Fritz Obermeyer
b2da9fd220 [distributions] Rename .params to .arg_constraints, fix logic (#5989) 2018-03-25 15:24:32 +02:00
lazypanda1
7f864bbe52 Fixed distribution constraints and added some test cases for distributions parameter check (#5358) 2018-03-15 23:11:20 +01:00
Sam Gross
54b4cdeffa Replace all uses of 'Tensor or Variable' with 'Tensor' (#5508)
Replace all uses of 'Tensor or Variable'  and 'Variable or Tensor' with 'Tensor'
2018-03-02 14:26:11 -05:00
Sam Gross
30ec06c140 Merge Variable and Tensor classes (#5225)
This replaces the torch.Tensor constructors with factories that produce
Variables. Similarly, functions on the torch module (e.g. torch.randn)
now return Variables.

To keep the PR to a reasonable size, I've left most of the unused tensor
code. Subsequent PRs will remove the dead code, clean-up calls to
torch.autograd.Variable, and rename Variable to Tensor everywhere.

There are some breaking changes because Variable and Tensors had
slightly different semantics. There's a list of those changes here:

 https://github.com/pytorch/pytorch/wiki/Breaking-Changes-from-Variable-and-Tensor-merge
2018-02-23 18:03:31 -05:00
gchanan
3b63e552f9 Fix test_distributions when WITH_SCALARS. (#5121)
* Fix test_distributions when WITH_SCALARS.

* Use SCALAR_SHAPE in test, use self.scale in AffineTransform.

* Handle device correctly for scalars.

* Fix one hot categorical.

* Fix relaxed categorical.

* Add a new_tensor instance method to Variable that takes only data.

This is to work around the legacy problems of new, where e.g.
new(5) will give you an unfilled tensor rather than a scalar.

* Fix cuda scalar code path.

* Remove double return.

* Work around lack of WITH_SCALARS.

* Use tensor_new.
2018-02-09 11:01:13 -05:00
Tongzhou Wang
47ee86776e Fix CPU torch.multinomial with noncontiguous prob tensor (#5093)
* fix CPU torch.multinomial not working on noncontiguous probability distribution

* address comments

* change some tabs to spaces in THStorage.c
2018-02-06 22:11:43 -05:00
Alican Bozkurt
20fbdb9a8b Adding mean, variance, stddev to distributions (#4923) 2018-01-31 00:26:32 +01:00
Neeraj Pradhan
b37aa2bf0e Ensure lazy evaluation for probs and logits (#4691) 2018-01-17 22:36:40 +01:00
Alican Bozkurt
9b6441ecbc Implement Multinomial distribution (#4624) 2018-01-13 11:26:14 +01:00