Commit Graph

34 Commits

Author SHA1 Message Date
Randolf Scholz
6c38b9be73 [typing] Add type hints to __init__ methods in torch.distributions. (#144197)
Fixes #144196
Extends #144106 and #144110

## Open Problems:

- [ ] Annotating with `numbers.Number` is a bad idea, should consider using `float`, `SupportsFloat` or some `Procotol`. https://github.com/pytorch/pytorch/pull/144197#discussion_r1903324769

# Notes

- `beta.py`: needed to add `type: ignore` since `broadcast_all` is untyped.
- `categorical.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- ~~`dirichlet.py`: replaced `axis` with `dim` arguments.~~ #144402
- `gemoetric.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- ~~`independent.py`: fixed bug in `Independent.__init__` where `tuple[int, ...]` could be passed to `Distribution.__init__` instead of `torch.Size`.~~ **EDIT:** turns out the bug is related to typing of `torch.Size`. #144218
- `independent.py`: made `Independent` a generic class of its base distribution.
- `multivariate_normal.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2].
- `relaxed_bernoulli.py`: added class-level type hint for `base_dist`.
- `relaxed_categorical.py`: added class-level type hint for `base_dist`.
- ~~`transforms.py`: Added missing argument to docstring of `ReshapeTransform`~~ #144401
- ~~`transforms.py`: Fixed bug in `AffineTransform.sign` (could return `Tensor` instead of `int`).~~ #144400
- `transforms.py`: Added `type: ignore` comments to `AffineTransform.log_abs_det_jacobian`[^1]; replaced `torch.abs(scale)` with `scale.abs()`.
- `transforms.py`: Added `type: ignore` comments to `AffineTransform.__eq__`[^1].
- `transforms.py`: Fixed type hint on `CumulativeDistributionTransform.domain`. Note that this is still an LSP violation, because `Transform.domain` is defined as `Constraint`, but `Distribution.domain` is defined as `Optional[Constraint]`.
- skipped: `constraints.py`, `constraints_registry.py`, `kl.py`, `utils.py`, `exp_family.py`, `__init__.py`.

## Remark

`TransformedDistribution`: `__init__` uses the check `if reinterpreted_batch_ndims > 0:`, which can lead to the creation of `Independent` distributions with only 1 component. This results in awkward code like `base_dist.base_dist` in `LogisticNormal`.

```python
import torch
from torch.distributions import *
b1 = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
b2 = MultivariateNormal(torch.tensor([0.0]), torch.eye(1))
t = StickBreakingTransform()
d1 = TransformedDistribution(b1, t)
d2 = TransformedDistribution(b2, t)
print(d1.base_dist)  # Independent with 1 dimension
print(d2.base_dist)  # MultivariateNormal
```

One could consider changing this to `if reinterpreted_batch_ndims > 1:`.

[^1]: Usage of `isinstance(value, numbers.Real)` leads to problems with static typing, as the `numbers` module is not supported by `mypy` (see <https://github.com/python/mypy/issues/3186>). This results in us having to add type-ignore comments in several places
[^2]: Otherwise, we would have to add a bunch of `type: ignore` comments to make `mypy` happy, as it isn't able to perform the type narrowing. Ideally, such code should be replaced with structural pattern matching once support for Python 3.9 is dropped.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144197
Approved by: https://github.com/malfet

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
2025-04-06 17:50:35 +00:00
Xuehai Pan
995df34b19 [BE][PYFMT] migrate PYFMT for torch.{distributed,distributions} to ruff format (#144547)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144547
Approved by: https://github.com/kwen2501
2025-02-28 07:35:56 +00:00
Randolf Scholz
64cd81712d torch.distributions: replace numbers.Number with torch.types.Number. (#145086)
Fixes #144788 (partial)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145086
Approved by: https://github.com/malfet
2025-01-27 20:24:55 +00:00
Randolf Scholz
355b0bc7e3 [typing] Add type hints to @property and @lazy_property in torch.distributions. (#144110)
Fixes #76772, #144196
Extends #144106

- added type annotations to `lazy_property`.
- added type annotation to all `@property` and `@lazy_property` inside `torch.distributions` module.
- added simply type-check unit test to ensure type inference is working.
- replaced deprecated annotations like `typing.List` with the corresponding counterpart.
- simplified `torch.Tensor` hints with plain `Tensor`, otherwise signatures can become very verbose.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144110
Approved by: https://github.com/Skylion007
2025-01-07 19:27:36 +00:00
Christopher Yeh
e72e924eb5 Add correct typing annotations to rsample() for all distributions (#133516)
Fixes #133514
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133516
Approved by: https://github.com/Skylion007
2024-08-18 20:31:54 +00:00
Xuehai Pan
b25ef91bf1 [BE][Easy][18/19] enforce style for empty lines in import segments in torch/d*/ (#129770)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129770
Approved by: https://github.com/wconstab
2024-08-01 04:22:50 +00:00
Aaron Orenstein
7c12cc7ce4 Flip default value for mypy disallow_untyped_defs [6/11] (#127843)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127843
Approved by: https://github.com/oulgen
ghstack dependencies: #127842
2024-06-08 18:49:29 +00:00
a-r-r-o-w
e08577aec5 Spelling fix (#108490)
Fixes spelling mistake: non-deterinistic -> non-deterministic
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108490
Approved by: https://github.com/ezyang
2023-09-04 16:59:35 +00:00
Edward Z. Yang
3bf922a6ce Apply UFMT to low traffic torch modules (#106249)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106249
Approved by: https://github.com/Skylion007
2023-07-29 23:37:30 +00:00
Xuehai Pan
5b1cedacde [BE] [2/3] Rewrite super() calls in functorch and torch (#94588)
Rewrite Python built-in class `super()` calls. Only non-semantic changes should be applied.

- #94587
- #94588
- #94592

Also, methods with only a `super()` call are removed:

```diff
class MyModule(nn.Module):
-   def __init__(self):
-       super().__init__()
-
    def forward(self, ...):
        ...
```

Some cases that change the semantics should be kept unchanged. E.g.:

f152a79be9/caffe2/python/net_printer.py (L184-L190)

f152a79be9/test/test_jit_fuser_te.py (L2628-L2635)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94588
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-10 21:16:33 +00:00
joncrall
4618371da5 Integrate xdoctest - Rebased (#82797)
This is a new version of #15648 based on the latest master branch.

Unlike the previous PR where I fixed a lot of the doctests in addition to integrating xdoctest, I'm going to reduce the scope here. I'm simply going to integrate xdoctest, and then I'm going to mark all of the failing tests as "SKIP". This will let xdoctest run on the dashboards, provide some value, and still let the dashboards pass. I'll leave fixing the doctests themselves to another PR.

In my initial commit, I do the bare minimum to get something running with failing dashboards. The few tests that I marked as skip are causing segfaults. Running xdoctest results in 293 failed, 201 passed tests. The next commits will be to disable those tests. (unfortunately I don't have a tool that will insert the `#xdoctest: +SKIP` directive over every failing test, so I'm going to do this mostly manually.)

Fixes https://github.com/pytorch/pytorch/issues/71105

@ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82797
Approved by: https://github.com/ezyang
2022-08-12 02:08:01 +00:00
anjali411
3bcc19b29a Add __all__ to various submodules in torch.fx, distributions, distributed, package (#80367)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80367
Approved by: https://github.com/albanD
2022-06-27 21:27:30 +00:00
Till Hoffmann
40576bceaf Add mode property to distributions. (#76690)
This PR fixes #69466 and introduces some other minor changes. Tests are somewhat more involved because a reference implementation in `scipy` is not available; tests proceed differently for discrete and continuous distributions.

For continuous distributions, we evaluate the gradient of the `log_prob` at the mode. Tests pass if the gradient is zero OR (the mode is at the boundary of the support of the distribution AND the `log_prob` decreases as we move away from the boundary to the interior of the support).

For discrete distributions, the notion of a gradient is not well defined. We thus "look" ahead and behind one step (e.g. if the mode of a Poisson distribution is 9, we consider 8 and 10). If the step ahead/behind is still within the support of the distribution, we assert that the `log_prob` is smaller than at the mode.

For one-hot encoded distributions (currently just `OneHotCategorical`), we evaluate the underlying mode (i.e. encoded as an integral tensor), "advance" by one label to get another sample that should have lower probability using `other = (mode + 1) % event_size` and re-encode as one-hot. The resultant `other` sample should have lower probability than the mode.

Furthermore, Gamma, half Cauchy, and half normal distributions have their support changed from positive to nonnegative. This change is necessary because the mode of the "half" distributions is zero, and the mode of the gamma distribution is zero for `concentration <= 1`.

cc @fritzo

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76690
Approved by: https://github.com/neerajprad
2022-05-11 18:26:56 +00:00
Thomas Viehmann
76b58dd9ae Fix distributions which don't properly honor validate_args=False (#53600)
Summary:
A number of derived distributions use base distributions in their
implementation.

We add what we hope is a comprehensive test whether all distributions
actually honor skipping validation of arguments in log_prob and then
fix the bugs we found. These bugs are particularly cumbersome in
PyTorch 1.8 and master when validate_args is turned on by default
In addition one might argue that validate_args is not performing
as well as it should when the default is not to validate but the
validation is turned on in instantiation.

Arguably, there is another set of bugs or at least inconsistencies
when validation of inputs does not prevent invalid indices in
sample validation (when with validation an IndexError is raised
in the test). We would encourage the implementors to be more
ambitious when validation is turned on and amend sample validation
to throw a ValueError for consistency.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/53600

Reviewed By: mrshenli

Differential Revision: D26928088

Pulled By: neerajprad

fbshipit-source-id: 52784a754da2faee1a922976e2142957c6c02e28
2021-03-10 13:16:32 -08:00
Xu Zhao
146721f1df Fix typing errors in the torch.distributions module (#45689)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/42979.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/45689

Reviewed By: agolynski

Differential Revision: D24229870

Pulled By: xuzhao9

fbshipit-source-id: 5fc87cc428170139962ab65b71cacba494d46130
2020-10-12 10:29:45 -07:00
Neeraj Pradhan
fb40e58f24 Remove deprecated tensor constructors in torch.distributions (#19979)
Summary:
This removes the deprecated `tensor.new_*` constructors (see #16770) from `torch.distributions` module.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19979

Differential Revision: D15195618

Pulled By: soumith

fbshipit-source-id: 46b519bfd32017265e90bd5c53f12cfe4a138021
2019-05-02 20:45:02 -07:00
Neeraj Pradhan
80fa8e1007 Add .expand() method to distribution classes (#11341)
Summary:
This adds a `.expand` method for distributions that is akin to the `torch.Tensor.expand` method for tensors. It returns a new distribution instance with batch dimensions expanded to the desired `batch_shape`. Since this calls `torch.Tensor.expand` on the distribution's parameters, it does not allocate new memory for the expanded distribution instance's parameters.

e.g.
```python
>>> d = dist.Normal(torch.zeros(100, 1), torch.ones(100, 1))
>>> d.sample().shape
  torch.Size([100, 1])
>>> d.expand([100, 10]).sample().shape
  torch.Size([100, 10])
```

We have already been using the `.expand` method in Pyro in our [patch](https://github.com/uber/pyro/blob/dev/pyro/distributions/torch.py#L10) of `torch.distributions`. We use this in our models to enable dynamic broadcasting. This has also been requested by a few users on the distributions slack, and we believe will be useful to the larger community.

Note that currently, there is no convenient and efficient way to expand distribution instances:
 - Many distributions use `TransformedDistribution` (or wrap over another distribution instance. e.g. `OneHotCategorical` uses a `Categorical` instance) under the hood, or have lazy parameters. This makes it difficult to collect all the relevant parameters, broadcast them and construct new instances.
 - In the few cases where this is even possible, the resulting implementation would be inefficient since we will go through a lot of broadcasting and args validation logic in `__init__.py` that can be avoided.

The `.expand` method allows for a safe and efficient way to expand distribution instances. Additionally, this bypasses `__init__.py` (using `__new__` and populating relevant attributes) since we do not need to do any broadcasting or args validation (which was already done when the instance was first created). This can result in significant savings as compared to constructing new instances via `__init__` (that said, the `sample` and `log_prob` methods will probably be the rate determining steps in many applications).

e.g.
```python
>>> a = dist.Bernoulli(torch.ones([10000, 1]), validate_args=True)

>>> %timeit a.expand([10000, 100])
15.2 µs ± 224 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

>>> %timeit dist.Bernoulli(torch.ones([10000, 100]), validate_args=True)
11.8 ms ± 153 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```

cc. fritzo, apaszke, vishwakftw, alicanb
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11341

Differential Revision: D9728485

Pulled By: soumith

fbshipit-source-id: 3b94c23bc6a43ee704389e6287aa83d1e278d52f
2018-09-11 06:56:18 -07:00
vishwakftw
f940af6293 Bag of Distributions doc fixes (#10894)
Summary:
- Added `__repr__` for Constraints and Transforms.
- Arguments passed to the constructor are now rendered with :attr:

Closes https://github.com/pytorch/pytorch/issues/10884
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10894

Differential Revision: D9514161

Pulled By: apaszke

fbshipit-source-id: 4abf60335d876449f2b6477eb9655afed9d5b80b
2018-08-27 09:55:27 -07:00
Vishwak Srinivasan
3cbaa6b785 [ready] Clean up torch.distributions (#8046) 2018-06-02 16:54:53 +02:00
li-roy
d564ecb4a5 Update docs with new tensor repr (#6454)
* Update docs with new tensor repr

* remove cuda in dtype

* remove changes to gloo submodule

* [docs] document tensor.new_* ctor

* [docs] Add docs for tensor.to(), tensor.float(), etc

* [docs] Moar examples for docs.

* [docs] Warning for tensor ctor copy behavior

* Quick fix

* [docs] Document requires_grad_()

* [docs] Add example for requires_grad_()

* update slogdet and *fft

* update tensor rst

* small fixes

* update some docs

* additional doc changes

* update torch and tensor docs

* finish changing tensor docs

* fix flake8

* slogdet with negative det

* Update functional.py tensor ctors

* Fix nll_loss docs

* reorder to move device up

* torch.LongTensor -> torch.tensor or torch.empty in docs

* update tensor constructors in docs

* change tensor constructors

* change constructors

* change more Tensor() to tensor()

* Show requires_grads_ docs

* Fix set_default_dtype docs

* Update docs with new tensor repr

* remove cuda in dtype

* remove changes to gloo submodule

* [docs] document tensor.new_* ctor

* [docs] Add docs for tensor.to(), tensor.float(), etc

* [docs] Moar examples for docs.

* [docs] Warning for tensor ctor copy behavior

* Quick fix

* [docs] Document requires_grad_()

* [docs] Add example for requires_grad_()

* update slogdet and *fft

* update tensor rst

* small fixes

* update some docs

* additional doc changes

* update torch and tensor docs

* finish changing tensor docs

* fix flake8

* slogdet with negative det

* Update functional.py tensor ctors

* Fix nll_loss docs

* reorder to move device up

* torch.LongTensor -> torch.tensor or torch.empty in docs

* update tensor constructors in docs

* change tensor constructors

* change constructors

* change more Tensor() to tensor()

* Show requires_grads_ docs

* Fix set_default_dtype docs

* Link to torch.no_grad, etc, from torch doc

* Add dtype aliases to table

* regen docs again

* Tensor attributes stub page

* link to inplace sampling

* Link torch.dtype, device, and layout

* fix dots after nonfinite floats

* better layout docs
2018-04-21 07:35:37 -04:00
Fritz Obermeyer
1d51dd8665 [distributions] Fix Independent.rsample() and add more tests (#6806) 2018-04-20 21:55:39 +02:00
gchanan
db53389761
Add numpy.array-like type inference to torch.tensor. (#5997)
* Add numpy.array-like type inference to torch.tensor.

* Temporary fix for int/double types.

* Treat python floats as the default (scalar) dtype.

* Also make 0-length sequences the default scalar type and add more tests.

* Add type inference to sparse_coo_tensor.

* Fix sparse test.

* Remove allow_variables.

* Check numpy platform bits.

* Address review comments.

* Make suggested changes to constraints.

* More checking windows builds.

* Fix test for windows.
2018-03-27 15:27:23 -04:00
Fritz Obermeyer
b2da9fd220 [distributions] Rename .params to .arg_constraints, fix logic (#5989) 2018-03-25 15:24:32 +02:00
lazypanda1
7f864bbe52 Fixed distribution constraints and added some test cases for distributions parameter check (#5358) 2018-03-15 23:11:20 +01:00
Sam Gross
54b4cdeffa
Replace all uses of 'Tensor or Variable' with 'Tensor' (#5508)
Replace all uses of 'Tensor or Variable'  and 'Variable or Tensor' with 'Tensor'
2018-03-02 14:26:11 -05:00
gchanan
da894901ef
Deprecate variable factory, use torch.tensor instead (#5476)
* Remove usages of torch.autograd.variable; use torch.tensor instead.

* Deprecate torch.autograd.variable.

* Remove unused sample_scalar.
2018-03-01 10:58:16 -05:00
Vishwak Srinivasan
85a7e0fc41 Addition of ExponentialFamily (#4876) 2018-02-04 12:18:28 +01:00
Alican Bozkurt
20fbdb9a8b Adding mean, variance, stddev to distributions (#4923) 2018-01-31 00:26:32 +01:00
gchanan
e37f02469d
Favor Variables over Tensors for scalar constructors in torch.distrib… (#4791)
* Favor Variables over Tensors for scalar constructors in torch.distributions.

Current behvior:
1) distribution constructors containing only python number elements will have their python numbers upcasted to Tensors.
2) Python number arguments of distribution constructors that also contain tensors and variables will be upcasted
to the first tensor/variable type.

This PR changes the above to favor Variables as follows:
1) The python numbers will now be upcasted to Variables
2) An error will be raised if the first tensor/variable type is not a Variable.

This is done in preparation for the introduction of Scalars (0-dimensional tensors), which are only available on the Variable API.
Note that we are (separately) merging Variable and Tensor, so this PR should have no real long-term effect.

Also note that the above means we don't change the behavior of constructors without python number arguments.

* Fix tests that require numpy.
2018-01-23 11:49:15 -05:00
Alican Bozkurt
8c2d35c754 Refactor distributions (#4688) 2018-01-17 11:58:08 +01:00
Vishwak Srinivasan
86fe793948 Addition of KL-Divergences for torch.distributions (#4638) 2018-01-14 22:52:28 +01:00
Fritz Obermeyer
a3e91515de Declare constraints for distribution parameters and support (#4450) 2018-01-04 23:58:26 +01:00
Fritz Obermeyer
0bc1505f34 Implement .entropy() methods for all distributions (#4268) 2017-12-20 14:06:01 +01:00
Fritz Obermeyer
ee98e7a82e Implement Dirichlet and Beta distributions (#4117) 2017-12-18 19:11:37 +01:00