Commit Graph

32 Commits

Author SHA1 Message Date
Xu Zhao
146721f1df Fix typing errors in the torch.distributions module (#45689)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/42979.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/45689

Reviewed By: agolynski

Differential Revision: D24229870

Pulled By: xuzhao9

fbshipit-source-id: 5fc87cc428170139962ab65b71cacba494d46130
2020-10-12 10:29:45 -07:00
Ralf Gommers
726aa713d5 Replace torch.is_tensor usages with isinstance checks. (#38062)
Summary:
`is_tensor` doesn't really have a reason to exist anymore (other than
backwards compatibility) and is worse for typechecking with mypy (see
gh-32824). Given that it may not be obvious what the fix is once mypy
gives an error, make the change in a number of places at once, and add
a note on this to the `is_tensor` docstring.

Recommending an isinstance check instead has been done for quite a
while, e.g. https://github.com/pytorch/pytorch/pull/7769#discussion_r190458971
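A minimal sketch of the recommended check (variable names are illustrative):

```python
import torch

x = torch.zeros(3)

# preferred: an explicit isinstance check, which mypy can narrow on
if isinstance(x, torch.Tensor):
    print(x.shape)

# legacy spelling, kept only for backwards compatibility
if torch.is_tensor(x):
    print(x.shape)
```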
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38062

Differential Revision: D21470963

Pulled By: ezyang

fbshipit-source-id: 98dd60d32ca0650abd2de21910b541d32b0eea41
2020-05-08 10:10:11 -07:00
Neeraj Pradhan
fb40e58f24 Remove deprecated tensor constructors in torch.distributions (#19979)
Summary:
This removes the deprecated `tensor.new_*` constructors (see #16770) from the `torch.distributions` module.
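A hypothetical before/after, assuming a legacy `new`-style call site like those this PR removes:

```python
import torch

x = torch.randn(3, dtype=torch.float64)

# deprecated: construction tied to an existing tensor's type
# y = x.new(2, 2).fill_(0.5)

# replacement: a factory function with explicit dtype/device
y = torch.full((2, 2), 0.5, dtype=x.dtype, device=x.device)
```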
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19979

Differential Revision: D15195618

Pulled By: soumith

fbshipit-source-id: 46b519bfd32017265e90bd5c53f12cfe4a138021
2019-05-02 20:45:02 -07:00
Edward Yang
173f224570 Turn on F401: Unused import warning. (#18598)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18598
ghimport-source-id: c74597e5e7437e94a43c163cee0639b20d0d0c6a

Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#18598 Turn on F401: Unused import warning.**

This was requested by someone at Facebook; this lint is turned
on for Facebook by default.  "Sure, why not."

I had to noqa a number of imports in __init__.  Hypothetically
we're supposed to use __all__ in this case, but I was too lazy
to fix it.  Left for future work.

Be careful!  flake8-2 and flake8-3 behave differently with
respect to import resolution for # type: comments.  flake8-3 will
report an import unused; flake8-2 will not.  For now, I just
noqa'd all these sites.

All the changes were done by hand.
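For illustration, a suppressed re-export might look like this (module and name are placeholders):

```python
# a re-export in an __init__.py looks unused to flake8, so it gets a
# targeted suppression instead of being deleted
from .bernoulli import Bernoulli  # noqa: F401
```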

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Differential Revision: D14687478

fbshipit-source-id: 30d532381e914091aadfa0d2a5a89404819663e3
2019-03-30 09:01:17 -07:00
vishwakftw
b4c3268b23 Batched upper triangular, lower triangular (#15257)
Summary:
Changelog:

- Implement `triu` and `tril` for batches of 2D tensors (see the sketch below).
- Remove the TH/THC binding for `tril`.
- Fix the CUDA implementation.
- Update docstrings for `tril` and `triu`.
- Remove mask-based `triu` and `tril` in cholesky forward and backward.
- Remove batched `tril` in `torch.distributions.utils`.
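A short sketch of the batched behavior this adds:

```python
import torch

x = torch.randn(4, 3, 3)            # a batch of four 3x3 matrices
upper = torch.triu(x)               # upper triangle of each matrix
lower = torch.tril(x, diagonal=-1)  # strictly lower triangle of each matrix
```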
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15257

Differential Revision: D13613888

Pulled By: mrshenli

fbshipit-source-id: 0949a05b9b8e974c1acfaf02a6284848ec5cc1c4
2019-01-09 19:46:39 -08:00
vishwakftw
6911ce19d7 Remove _finfo; replace _finfo usage with torch.finfo (#15165)
Summary:
This PR removes the usage of `_finfo` defined in `torch.distributions.utils` and changes the call sites to use `torch.finfo` instead.
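For example, the stability constants come straight from `torch.finfo`:

```python
import torch

finfo = torch.finfo(torch.float32)
print(finfo.eps)   # machine epsilon for float32
print(finfo.tiny)  # smallest positive normal number
```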

Differential Revision: D13451936

Pulled By: soumith

fbshipit-source-id: 6dbda3a6179d9407bc3396bf1a2baf3e85bc4cf2
2018-12-13 14:30:27 -08:00
Fritz Obermeyer
2431eac7c0 Ensure most Distribution methods are jittable (#11560)
Summary:
This adds tests in tests/test_distributions.py to ensure that all methods of `Distribution` objects are jittable.

I've replaced a few samplers with jittable versions (see the sketch after this list):
- `.uniform_()` -> `torch.rand()`
- `.exponential_()` -> `-(-torch.rand()).log1p()`
- `.normal_()` -> `torch.normal(torch.zeros(...), torch.ones(...), ...)`
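The `.exponential_()` replacement is inverse-CDF sampling; a minimal sketch:

```python
import torch

# if U ~ Uniform(0, 1), then -log(1 - U) ~ Exponential(1); log1p keeps the
# computation accurate when U is close to 0
u = torch.rand(1000)
samples = -(-u).log1p()   # equivalent to -torch.log(1 - u)
```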

Some jit failures remain, and are marked in test_distributions.py
- `Cauchy` and `HalfCauchy` do not support sampling due to missing `.cauchy_()`
- `Binomial` does not support `.enumerate_support()` due to `arange` ignoring its first arg.
- `MultivariateNormal`, `LowRankMultivariateNormal` do not support `.mean`, `.entropy`

- [x] Currently some tests fail (I've skipped those) due to unavailability of `aten::uniform` and `aten::cauchy` in the jit. Can someone suggest how to add these? I tried to add declarations to `torch/csrc/ir.cpp` and `torch/csrc/passes/shape_analysis.cpp`, but that resulted in "Couldn't find operator" errors.
- [x] There are still lots of `TracerWarning`s that something doesn't match something. I'm not sure whether these are real.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11560

Differential Revision: D9816327

Pulled By: apaszke

fbshipit-source-id: 72ec998ea13fc4c76d1ed003d9502e0fbaf728b8
2018-09-13 19:55:01 -07:00
Richard Zou
6b338c8026 Implement torch.broadcast_tensors (#10075)
Summary:
This exposes expand_outplace to python. Fixes #8076. Fixes #10041.

I didn't name it torch.broadcast because numpy.broadcast does something
slightly different (it returns an object with the correct shape
information).
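A quick usage sketch:

```python
import torch

a = torch.randn(3, 1)
b = torch.randn(1, 4)
a2, b2 = torch.broadcast_tensors(a, b)
print(a2.shape, b2.shape)  # torch.Size([3, 4]) torch.Size([3, 4])
```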
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10075

Differential Revision: D9125816

Pulled By: zou3519

fbshipit-source-id: ebe17c8bb54a73ec84b8f76ce14aff3e9c56f4d1
2018-08-01 19:18:34 -07:00
Alican Bozkurt
8766daeec9 Refactor _log_sum_exp (#9173)
Summary:
This PR removes `distributions.utils._log_sum_exp` in favor of `torch.logsumexp`. Also fixes some warnings with the `reduce` argument in `binary_cross_entropy_with_logits`.
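A sketch of the replacement call:

```python
import torch

x = torch.randn(5, 10)
# numerically stable log(sum(exp(x))) over the last dimension
out = torch.logsumexp(x, dim=-1)
assert torch.allclose(out, x.exp().sum(dim=-1).log())
```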
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9173

Reviewed By: SsnL

Differential Revision: D8764174

Pulled By: ezyang

fbshipit-source-id: b9c4136dbf0182e8ae77082e6448d23a430d5cb6
2018-07-15 17:40:53 -07:00
Wei Yang
cb1bfe91af Deprecated several functions at torch.nn.functional (#8748)
Summary:
1. fixes #6245
2. deprecated tanh, sigmoid
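The torch-namespace equivalents take over from the deprecated functional versions; a sketch:

```python
import torch

x = torch.randn(3)
y = torch.sigmoid(x)  # instead of the deprecated torch.nn.functional.sigmoid
z = torch.tanh(x)     # instead of the deprecated torch.nn.functional.tanh
```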
Closes https://github.com/pytorch/pytorch/pull/8748

Differential Revision: D8697975

Pulled By: weiyangfb

fbshipit-source-id: f30714aa0611a1fe870040692f3dbcc8238aece9
2018-07-02 15:54:46 -07:00
Vishwak Srinivasan
61f61de270 Expose logsumexp docs and mark log_sum_exp in distributions for internal use (#8428) 2018-06-13 12:27:58 -04:00
Vishwak Srinivasan
3cbaa6b785 [ready] Clean up torch.distributions (#8046) 2018-06-02 16:54:53 +02:00
Tongzhou Wang
c946db16ec
[distributions] Always enable grad when calculating lazy_property (#7708)
* Always enable grad when calculating lazy_property

* Add test with MultivariateNormal
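A hedged sketch of the descriptor pattern this describes (the real implementation lives in torch.distributions.utils and may differ in detail):

```python
import torch

class lazy_property:
    """Cache a property's value; compute it with grad enabled so gradients
    flow even when the attribute is first touched inside a no_grad block."""
    def __init__(self, wrapped):
        self.wrapped = wrapped

    def __get__(self, instance, obj_type=None):
        if instance is None:
            return self
        with torch.enable_grad():
            value = self.wrapped(instance)
        # overwrite the descriptor with the computed value on the instance
        setattr(instance, self.wrapped.__name__, value)
        return value
```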
2018-05-24 11:22:39 -04:00
Tongzhou Wang
1c01eabd3c
Codemod to update our codebase to 0.4 standard (#6641)
* Codemod to update our codebase to 0.4 standard

* Update some of the test scripts

* remove Variable in test_clip_grad_value

* fix _symbolic_override_wrapper_maker
2018-04-17 22:06:54 -04:00
gchanan
db53389761
Add numpy.array-like type inference to torch.tensor. (#5997)
* Add numpy.array-like type inference to torch.tensor (see the sketch below).

* Temporary fix for int/double types.

* Treat python floats as the default (scalar) dtype.

* Also make 0-length sequences the default scalar type and add more tests.

* Add type inference to sparse_coo_tensor.

* Fix sparse test.

* Remove allow_variables.

* Check numpy platform bits.

* Address review comments.

* Make suggested changes to constraints.

* More checking windows builds.

* Fix test for windows.
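A sketch of the inference rules described above (outputs assume the current defaults):

```python
import torch

print(torch.tensor([1, 2, 3]).dtype)   # torch.int64, inferred from the ints
print(torch.tensor([1.0, 2.0]).dtype)  # torch.float32: python floats map to
                                       # the default scalar dtype
print(torch.tensor([]).dtype)          # torch.float32: 0-length sequences
                                       # also fall back to the default
```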
2018-03-27 15:27:23 -04:00
Brooks
1936753708 Added an implementation of a multivariate normal distribution (#4950) 2018-03-19 23:22:46 +01:00
Sam Gross
54b4cdeffa
Replace all uses of 'Tensor or Variable' with 'Tensor' (#5508)
Replace all uses of 'Tensor or Variable' and 'Variable or Tensor' with 'Tensor'
2018-03-02 14:26:11 -05:00
Sam Gross
70ba50c3d4 Remove some uses of torch.is_tensor in favor of isinstance (#5473) 2018-03-02 06:17:38 -05:00
gchanan
da894901ef
Deprecate variable factory, use torch.tensor instead (#5476)
* Remove usages of torch.autograd.variable; use torch.tensor instead.

* Deprecate torch.autograd.variable.

* Remove unused sample_scalar.
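A minimal sketch of the replacement:

```python
import torch

# instead of the deprecated torch.autograd.variable(...) factory
x = torch.tensor([1.0, 2.0], requires_grad=True)
```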
2018-03-01 10:58:16 -05:00
gchanan
d5038309a1
Remove WITH_SCALARS, as it's enabled by default now. (#5437) 2018-02-27 14:51:11 -05:00
Tongzhou Wang
47ee86776e Fix CPU torch.multinomial with noncontiguous prob tensor (#5093)
* fix CPU torch.multinomial not working on a noncontiguous probability distribution

* address comments

* change some tabs to spaces in THStorage.c
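A sketch of the case this fixes (shapes are illustrative):

```python
import torch

probs = torch.rand(4, 6)[:, ::2]   # a noncontiguous probability tensor
assert not probs.is_contiguous()
idx = torch.multinomial(probs, num_samples=2)  # now samples correctly on CPU
```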
2018-02-06 22:11:43 -05:00
Fritz Obermeyer
ca5071d072 Support multivariate TransformedDistributions (#4937) 2018-01-31 18:32:24 +01:00
gchanan
4970e73304
Add support for distributions and test_distributions when WITH_SCALAR… (#4834)
* Add support for distributions and test_distributions when WITH_SCALARS enabled.

* Fix flake8.
2018-01-24 19:22:05 -05:00
gchanan
e37f02469d
Favor Variables over Tensors for scalar constructors in torch.distrib… (#4791)
* Favor Variables over Tensors for scalar constructors in torch.distributions.

Current behavior:
1) Distribution constructors containing only Python number elements will have their Python numbers upcast to Tensors.
2) Python number arguments of distribution constructors that also contain tensors and variables will be upcast to the first tensor/variable type.

This PR changes the above to favor Variables as follows:
1) The Python numbers will now be upcast to Variables.
2) An error will be raised if the first tensor/variable type is not a Variable.

This is done in preparation for the introduction of Scalars (0-dimensional tensors), which are only available on the Variable API.
Note that we are (separately) merging Variable and Tensor, so this PR should have no real long-term effect.

Also note that the above means we don't change the behavior of constructors without python number arguments.

* Fix tests that require numpy.
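A sketch of the upcasting described above; on a current build the parameters come back as 0-dimensional tensors (Variables, at the time of this PR):

```python
import torch
from torch.distributions import Normal

d = Normal(0.0, 1.0)             # python numbers are upcast by the constructor
print(type(d.loc), d.loc.dim())  # <class 'torch.Tensor'> 0
```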
2018-01-23 11:49:15 -05:00
Alican Bozkurt
3254eca8c8 Implement binomial distribution (#4658) 2018-01-16 21:39:05 +01:00
Fritz Obermeyer
8cff8e93d2 Add torch.distributions.utils._finfo for numerical stability (#4572)
* Add torch.distributions.utils.finfo

* Make _finfo private

* Address review comments

* Simplify _finfo() to key on Storage type
2018-01-10 21:42:47 -05:00
Neeraj Pradhan
408c84de7c Supporting logits as parameters in Bernoulli and Categorical (#4448)
* Supporting logits as parameters in Bernoulli and Categorical

* address comments

* fix lint

* modify binary_cross_entropy_with_logits

* address comments

* add descriptor for lazy attributes

* address comments
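A minimal usage sketch:

```python
import torch
from torch.distributions import Bernoulli

d = Bernoulli(logits=torch.tensor([0.0, 2.0]))
print(d.probs)     # sigmoid of the logits, computed lazily
print(d.sample())
```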
2018-01-05 03:45:05 -05:00
Fritz Obermeyer
35abc4efa2 Add low-precision digamma() and polygamma() functions (#4399) 2018-01-02 11:53:23 +01:00
Fritz Obermeyer
0bc1505f34 Implement .entropy() methods for all distributions (#4268) 2017-12-20 14:06:01 +01:00
Neeraj Pradhan
fac711c238 Provide full support for distribution shapes (#4193) 2017-12-15 12:41:08 +01:00
Neeraj Pradhan
4f4e0df68f Allow for broadcasting of distribution parameters (#4140) 2017-12-14 09:37:03 +01:00
Neeraj Pradhan
ba93c031f2 Moving distribution classes into a separate package 2017-12-12 02:44:44 -08:00