Commit Graph

297 Commits

Author SHA1 Message Date
a-r-r-o-w
e08577aec5 Spelling fix (#108490)
Fixes spelling mistake: non-deterinistic -> non-deterministic
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108490
Approved by: https://github.com/ezyang
2023-09-04 16:59:35 +00:00
Aaron Gokaslan
660e8060ad [BE]: Update ruff to 0.285 (#107519)
This updates ruff to 0.285, which is faster, better, and fixes a bunch of false negatives with regard to f-strings.

I also enabled RUF017, which looks for accidental quadratic list summation. Luckily, there seem to be no instances of it in our codebase, so I'm enabling it to keep things that way. :)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
2023-08-22 23:16:38 +00:00
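As an illustration (not from the PR itself), the accidental quadratic pattern RUF017 flags, and the linear alternatives ruff suggests, look like this:

```python
import itertools
import operator
from functools import reduce

lists = [[1, 2], [3], [4, 5, 6]]

# Quadratic: sum() concatenates step by step, re-copying the accumulator
# list on every addition -- O(n^2) in the total number of elements.
flat_quadratic = sum(lists, [])

# Linear alternatives:
flat_chain = list(itertools.chain.from_iterable(lists))
flat_reduce = reduce(operator.iadd, lists, [])

assert flat_quadratic == flat_chain == flat_reduce == [1, 2, 3, 4, 5, 6]
```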
PyTorch MergeBot
d59a6864fb Revert "[BE]: Update ruff to 0.285 (#107519)"
This reverts commit 88ab3e4322.

Reverted https://github.com/pytorch/pytorch/pull/107519 on behalf of https://github.com/ZainRizvi due to Sorry, but this PR breaks internal tests. @ezyang, can you please help them get unblocked? It seems like one of the strings was probably accidentally modified ([comment](https://github.com/pytorch/pytorch/pull/107519#issuecomment-1688833480))
2023-08-22 19:53:32 +00:00
Aaron Gokaslan
88ab3e4322 [BE]: Update ruff to 0.285 (#107519)
This updates ruff to 0.285, which is faster, better, and fixes a bunch of false negatives with regard to f-strings.

I also enabled RUF017, which looks for accidental quadratic list summation. Luckily, there seem to be no instances of it in our codebase, so I'm enabling it to keep things that way. :)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
2023-08-20 01:36:18 +00:00
Austin
45f6ef2597 Expose intended public constraints. Fixes #106386 (#106458)
Fixes #106386

Straightforward change, just exposes the `one_hot` and `nonnegative` distribution constraints that are intended to be public. This fixes downstream pyro usage of these constraints.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106458
Approved by: https://github.com/ezyang, https://github.com/kit1980
2023-08-04 23:20:59 +00:00
Edward Z. Yang
b581e03850 Apply UFMT to torch/distributions/distribution.py, manually resolve fstrings (#106266)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106266
Approved by: https://github.com/Skylion007
2023-07-30 19:10:57 +00:00
Edward Z. Yang
3bf922a6ce Apply UFMT to low traffic torch modules (#106249)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106249
Approved by: https://github.com/Skylion007
2023-07-29 23:37:30 +00:00
Aaron Gokaslan
6d43c89f37 [BE]: Update Ruff to 0.0.280 (#105724)
Removes unused loop values in Python dictionary iteration. Automated fix from Ruff master.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105724
Approved by: https://github.com/ezyang, https://github.com/janeyx99
2023-07-22 23:03:34 +00:00
Justin Chu
4cc1745b13 [BE] f-stringify torch/ and scripts (#105538)
This PR is a follow-up in the pyupgrade series, converting more strings to f-strings using `flynt`.

- https://docs.python.org/3/reference/lexical_analysis.html#f-strings
- https://pypi.org/project/flynt/

Command used:

```
flynt torch/ -ll 120
flynt scripts/ -ll 120
flynt tools/ -ll 120
```

and excluded `collect_env.py`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105538
Approved by: https://github.com/ezyang, https://github.com/malfet
2023-07-21 19:35:24 +00:00
Justin Chu
3721fa5612 [BE] Enable ruff's UP rules and autoformat optim/ (#105426)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105426
Approved by: https://github.com/malfet, https://github.com/albanD, https://github.com/aaronenyeshi, https://github.com/janeyx99
2023-07-18 21:07:43 +00:00
Nikita Shulga
5837e95d30 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

That were reverted due to the conflict with internal source repo.

Mostly fixes for PEP-484 violations (i.e. when a default arg is set to None, but the type is not annotated as Optional)
Plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add assert in `torch/optim/optimizer.py` that the Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`

Unrelated, to bypass CI failures due to the gcc9 dependency update in Ubuntu-18.04:
- Add hack to squash older libstdc++ from conda environment in favor one from OS to `.ci/docker/install_conda.sh`
- Update bazel cuda builds to focal, as with libstdc++-6.0.32 bazel builds lose the ability to catch exceptions (probably because they link with cupti statically, but I could not find where that is done)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-15 20:30:20 +00:00
PyTorch MergeBot
15fd1ea118 Revert "[Reland] Update mypy to 1.4.1 (#105227)"
This reverts commit c9c4f8efc3.

Reverted https://github.com/pytorch/pytorch/pull/105227 on behalf of https://github.com/atalman due to trying to mitigate ci sev #105248 ([comment](https://github.com/pytorch/pytorch/pull/105227#issuecomment-1636510935))
2023-07-14 22:28:35 +00:00
Nikita Shulga
c9c4f8efc3 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

That were reverted due to the conflict with internal source repo.

Mostly fixes for PEP-484 violations (i.e. when a default arg is set to None, but the type is not annotated as Optional)
Plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add assert in `torch/optim/optimizer.py` that the Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-14 20:45:12 +00:00
PyTorch MergeBot
b4d91b1c5b Revert "[Typing] Fix PEP 484 Violation (#105022)"
This reverts commit 4148b7bada.

Reverted https://github.com/pytorch/pytorch/pull/105022 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/105022#issuecomment-1635967734))
2023-07-14 14:45:09 +00:00
Nikita Shulga
4148b7bada [Typing] Fix PEP 484 Violation (#105022)
Not sure how it worked before, but arguments must be annotated as Optional if they default to None.

Towards enabling mypy-1.4.1 in lintrunner

### <samp>🤖 Generated by Copilot at 5e1b9f4</samp>

> _We annotate the arguments of doom_
> _To show the `None` values of gloom_
> _We improve the type checking and readability_
> _With `Optional` annotations of metal-ity_

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105022
Approved by: https://github.com/izaitsevfb, https://github.com/huydhn, https://github.com/Skylion007
2023-07-12 10:20:48 +00:00
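A minimal sketch of the violation being fixed (the `lookup` helper is hypothetical): PEP 484 does not treat a bare `str = None` default as implicitly optional, so the annotation must be explicit:

```python
from typing import Optional

# Violation: `def lookup(key: str, default: str = None) -> str` annotates
# `default` as `str` while defaulting it to None; mypy 1.4.1 rejects this.

# Fix: spell out the Optional type explicitly.
def lookup(key: str, default: Optional[str] = None) -> str:
    table = {"a": "1", "b": "2"}
    value = table.get(key, default)
    return value if value is not None else ""

assert lookup("a") == "1"
assert lookup("z") == ""
assert lookup("z", "fallback") == "fallback"
```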
Kale Kundert
e75f7994e1 Fix Dirichlet.log_prob() when x=0 and alpha=1 (#103605)
`Dirichlet.log_prob()` incorrectly returns NaN in the case where $x_i=0$ and $\alpha_i=1$.  The Dirichlet PDF is given by:
$$\frac{1}{B(\alpha)} \prod_{i=1}^{K} x_i^{\alpha_i - 1}$$
So this corresponds to the case where one of the terms has the form $0^0=1$. The logarithm of such a term should be 0, but you get NaN if you try to calculate it as `0 * log(0)`.

This PR implements the same algorithm that `scipy.stats.dirichlet` uses to avoid this behavior, namely `xlogy(alpha - 1, x)` instead of `(alpha - 1) * log(x)`.  It also adds a test case comparing the pytorch and scipy implementations for this specific case.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103605
Approved by: https://github.com/albanD
2023-06-15 16:16:50 +00:00
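A pure-Python sketch of why `xlogy` matters here (the helper mirrors `scipy.special.xlogy`; it is not the actual torch kernel): in IEEE arithmetic `0 * (-inf)` is `nan`, so evaluating `(alpha - 1) * log(x)` at `alpha=1, x=0` poisons the log-probability, while `xlogy(0, 0)` is defined to be `0`:

```python
import math

def xlogy(x: float, y: float) -> float:
    """x * log(y), with the convention xlogy(0, y) == 0 (as in scipy.special.xlogy)."""
    if x == 0.0 and not math.isnan(y):
        return 0.0
    return x * math.log(y)

# Naive evaluation: the 0**0 == 1 term becomes 0 * (-inf) == nan.
assert math.isnan(0.0 * float("-inf"))

# xlogy evaluates the same term as log(0**0) == log(1) == 0.
alpha, x = 1.0, 0.0
assert xlogy(alpha - 1.0, x) == 0.0
```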
Matthew Hoffman
29da75cc55 Enable mypy allow redefinition (#102046)
Related #101528

I tried to enable this in another PR but it uncovered a bunch of type errors: https://github.com/pytorch/pytorch/actions/runs/4999748262/jobs/8956555243?pr=101528#step:10:1305

The goal of this PR is to fix these errors.

---

This PR enables [allow_redefinition = True](https://mypy.readthedocs.io/en/stable/config_file.html#confval-allow_redefinition) in `mypy.ini`, which allows for a common pattern:

> Allows variables to be redefined with an arbitrary type, as long as the redefinition is in the same block and nesting level as the original definition.

`allow_redefinition` allows mypy to be more flexible by allowing reassignment to an existing variable with a different type... for instance (from the linked PR):

4a1e9230ba/torch/nn/parallel/data_parallel.py (L213)

A `Sequence[Union[int, torch.device]]` is narrowed to `Sequence[int]` through reassignment to the same variable.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102046
Approved by: https://github.com/ezyang
2023-05-24 07:05:30 +00:00
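The pattern can be sketched with a hypothetical helper (the real case in `data_parallel.py` narrows `torch.device` elements; plain strings stand in for them here):

```python
from typing import Sequence, Union

def normalize_device_ids(device_ids: Sequence[Union[int, str]]) -> Sequence[int]:
    # With allow_redefinition, mypy permits rebinding the same name to a
    # narrower type in the same block: Sequence[Union[int, str]] is
    # narrowed to a list of int through reassignment.
    device_ids = [int(d) for d in device_ids]
    return device_ids

assert normalize_device_ids([0, "1", 2]) == [0, 1, 2]
```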
Sergii Dymchenko
e17d9f2c64 Fix determenistic typos (#101631)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101631
Approved by: https://github.com/lezcano, https://github.com/ZainRizvi
2023-05-17 16:12:28 +00:00
Alexis Thual
24cc7fe020 Fix Wishart distribution documentation (#95816)
This PR fixes the `torch.distributions.wishart.Wishart` example.

Running the current example
```python
m = Wishart(torch.eye(2), torch.Tensor([2]))
m.sample()  # Wishart distributed with mean=`df * I` and
            # variance(x_ij)=`df` for i != j and variance(x_ij)=`2 * df` for i == j
```
fails with
```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Untitled-1 in
      321 # %%
----> 322 m = Wishart(torch.eye(2), torch.Tensor([2]))
      323 m.sample()  # Wishart distributed with mean=`df * I` and
      324             # variance(x_ij)=`df` for i != j and variance(x_ij)=`2 * df` for i == j

Untitled-1 in __init__(self, df, covariance_matrix, precision_matrix, scale_tril, validate_args)
     83
     84         if param.dim() < 2:
---> 85             raise ValueError("scale_tril must be at least two-dimensional, with optional leading batch dimensions")
     86
     87         if isinstance(df, Number):

ValueError: scale_tril must be at least two-dimensional, with optional leading batch dimensions
```

It seems that the parameters of `Wishart.__init__()` were re-ordered, but the documentation was not updated.
This PR fixes it. Here is the updated behaviour:

```python
m = Wishart(torch.Tensor([2]), covariance_matrix=torch.eye(2))
m.sample()
```

```
Untitled-1:255: UserWarning: Singular sample detected.
tensor([[[6.6366, 0.7796],
         [0.7796, 0.2136]]])
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95816
Approved by: https://github.com/ngimel, https://github.com/kit1980
2023-05-16 02:07:30 +00:00
KuangDW
07d3772eff fix typo in comments under torch/distributions/mixture_same_family.py (#101290)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101290
Approved by: https://github.com/Skylion007
2023-05-13 14:25:52 +00:00
Aaron Gokaslan
738ba13b35 [BE]: enable PLE error codes in ruff and fix bugs (#101079)
Enables PyLint error codes implemented in ruff. These are un-opinionated static analysis checks on Python code that finds common bugs. After running all the PLE error codes that are implemented in ruff, I fixed the bugs, added a few ignores for malformed Python code that is part of our JIT test script, and finally added a few ignores for a false positive on PLE0605 and submitted an issue upstream to fix in ruff https://github.com/charliermarsh/ruff/issues/4345 .

Common bugs found here include analysis for malformed logging format calls, bad string format calls, invalid escape sequences, and more.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101079
Approved by: https://github.com/malfet
2023-05-11 23:57:25 +00:00
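One example of the class of bug these checks catch (illustrative, not a case from the PR): a malformed %-format call only blows up when the line actually executes, whereas a static check flags it up front:

```python
# A bad printf-style format call: %d requires a number, not a string.
# Static checks like ruff's PLE rules flag this without running the code.
try:
    message = "%d items" % "three"
except TypeError as exc:
    message = f"caught at runtime: {exc}"

assert message.startswith("caught at runtime:")
```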
gui11aume
7ec4392068 Remove in-place operations in NegativeBinomial (#96748)
This is a suggestion for a minor modification.

The line `log_normalization[self.total_count + value == 0.] = 0.` prevents Jit compilation when the condition occurs, with the error message

`RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation.`

I propose an alternative that does not involve in-place operations. It uses the function `nan_to_num()` to replace infinite values by 0 where `self.total_count + value == 0.` while leaving `nan` and `-inf` as they are. Readability is suboptimal because the code does not replace nan with numbers, but I could not find a function that only replaces infinite values.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96748
Approved by: https://github.com/fritzo, https://github.com/soulitzer
2023-04-26 14:45:08 +00:00
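The elementwise effect being relied on can be sketched in pure Python (roughly `torch.nan_to_num(x, nan=float('nan'), posinf=0.0, neginf=float('-inf'))`, which replaces only `+inf` and leaves `nan` and `-inf` untouched):

```python
import math

def zero_out_posinf(v: float) -> float:
    """Replace +inf with 0.0 while leaving nan, -inf, and finite values as
    they are -- the out-of-place substitute for masked in-place assignment."""
    return 0.0 if v == math.inf else v

assert zero_out_posinf(math.inf) == 0.0
assert math.isnan(zero_out_posinf(math.nan))
assert zero_out_posinf(-math.inf) == -math.inf
assert zero_out_posinf(1.5) == 1.5
```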
Aaron Gokaslan
597b558c51 [BE]: Update flake8 and plugins and fix bugs (#97795)
Update flake8 and flake8-plugins in lintrunner to a modern version. Enables more checks and makes flake8 checks significantly faster. Added a few additional rule ignores that will need to be fixed in the future.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97795
Approved by: https://github.com/alexsio27444, https://github.com/janeyx99, https://github.com/ezyang
2023-03-28 23:51:55 +00:00
Vladimir S. FONOV
b0b5f3c6c6 Fix gumbel cdf (#91698)
Fix `Gumbel.cdf` function.

**Description**
This fixes `Gumbel.cdf` when the transformed value falls outside the support of the underlying Uniform distribution, making the behavior of `Gumbel.cdf` consistent with other `TransformedDistribution` subclasses that pass the value of validate_args to the base distribution.

**Issue**
Running `Gumbel(0.0, 1.0, validate_args=False).cdf(20.0)` would raise a `ValueError` from `_validate_sample`.

**Testing**
A test was added to `test_distributions.py` to check that `Gumbel(0.0, 1.0, validate_args=False).cdf(20.0)` successfully returns `1.0`.

This is a second attempt to push these changes, after https://github.com/pytorch/pytorch/pull/82488

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91698
Approved by: https://github.com/fritzo, https://github.com/zou3519
2023-03-07 23:04:47 +00:00
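For reference, the Gumbel CDF itself is well defined far outside any sampled range; a pure-Python sketch (not the torch implementation):

```python
import math

def gumbel_cdf(x: float, loc: float = 0.0, scale: float = 1.0) -> float:
    """CDF of the Gumbel (max) distribution: exp(-exp(-(x - loc) / scale))."""
    return math.exp(-math.exp(-(x - loc) / scale))

# Far in the right tail the CDF saturates towards 1.0, which is what
# Gumbel(0.0, 1.0, validate_args=False).cdf(20.0) should return instead
# of raising from _validate_sample. (In float64 the value is within ~2e-9
# of 1.0; float32 rounds it to exactly 1.0.)
assert abs(gumbel_cdf(20.0) - 1.0) < 1e-8
assert abs(gumbel_cdf(0.0) - math.exp(-1.0)) < 1e-12
```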
Xuehai Pan
b005ec62b9 [BE] Remove dependency on six and future (#94709)
Remove the Python 2 and 3 compatibility libraries [six](https://pypi.org/project/six) and [future](https://pypi.org/project/future), along with `torch._six`. We only support Python 3.8+ now; it's time to retire them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94709
Approved by: https://github.com/malfet, https://github.com/Skylion007
2023-02-14 09:14:14 +00:00
Xuehai Pan
5b1cedacde [BE] [2/3] Rewrite super() calls in functorch and torch (#94588)
Rewrite Python built-in class `super()` calls. Only non-semantic changes should be applied.

- #94587
- #94588
- #94592

Also, methods with only a `super()` call are removed:

```diff
class MyModule(nn.Module):
-   def __init__(self):
-       super().__init__()
-
    def forward(self, ...):
        ...
```

Some cases that change the semantics should be kept unchanged. E.g.:

f152a79be9/caffe2/python/net_printer.py (L184-L190)

f152a79be9/test/test_jit_fuser_te.py (L2628-L2635)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94588
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-10 21:16:33 +00:00
Aaron Gokaslan
8fce9a09cd [BE]: pyupgrade Python to 3.8 - imports and object inheritance only (#94308)
Apply parts of pyupgrade to torch (starting with the safest changes).
This PR only does two things: removes the need to inherit from object and removes unused future imports.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94308
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-07 21:10:56 +00:00
Yanbo Liang
0ab4ab9f8d [Dynamo] Fix calling UserDefinedObject.func should pass self object (#92050)
Fixes #90834

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92050
Approved by: https://github.com/jansel
2023-01-21 05:47:01 +00:00
Peter Bell
206f4e47bb Replace exp(x) - 1 with expm1(x) (#92154)
This offers improved precision near zero where `exp(x)` is `1 + O(x)` and doing
`(1 + O(x)) - 1` will truncate anything below the float epsilon to zero.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92154
Approved by: https://github.com/lezcano
2023-01-18 10:43:57 +00:00
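The cancellation is easy to demonstrate (illustrative, not from the PR): for tiny x, `exp(x)` rounds to exactly 1.0 in float64, so the naive form returns 0 while `expm1` keeps the value:

```python
import math

x = 1e-20
naive = math.exp(x) - 1.0   # exp(1e-20) rounds to exactly 1.0, so this is 0.0
accurate = math.expm1(x)    # evaluates exp(x) - 1 directly, ~1e-20

assert naive == 0.0         # all significant digits truncated
assert accurate > 0.0       # expm1 preserves the small value
```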
Peter Bell
4058dedf21 Replace log(1 + x) with log1p(x) (#92114)
`log1p` offers better precision near zero since `(1 + x) - 1` truncates any
values less than the float epsilon to zero. For `soft_margin_loss` this also
requires one fewer kernel invocation which for numel=1e7 gives me a 1.2x speedup
on CUDA and a 1.1x speedup on CPU.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92114
Approved by: https://github.com/ngimel, https://github.com/lezcano
2023-01-18 10:43:56 +00:00
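The same truncation is visible directly (illustrative, not from the PR): forming `1 + x` first rounds the small value away, while `log1p` avoids ever computing `1 + x`:

```python
import math

x = 1e-20
naive = math.log(1.0 + x)   # 1.0 + 1e-20 rounds to exactly 1.0, so log is 0.0
accurate = math.log1p(x)    # evaluates log(1 + x) without forming 1 + x

assert naive == 0.0
assert accurate > 0.0
```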
joncrall
ad782ff7df Enable xdoctest runner in CI for real this time (#83816)
Builds on #83317 and enables running the doctests. Just need to figure out what is causing the failures.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83816
Approved by: https://github.com/ezyang, https://github.com/malfet
2022-12-29 05:32:42 +00:00
Felix Divo
7fecba7bdb Doc improvement in LKJCholesky distribution (#91091)
Better structure & formatting. Added more info to reference.

The change can be viewed here: https://docs-preview.pytorch.org/91091/distributions.html?highlight=lkjcholesky#torch.distributions.lkj_cholesky.LKJCholesky
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91091
Approved by: https://github.com/kit1980
2022-12-20 23:38:57 +00:00
ecao
eae0f3f5e3 Add mkl implementation for exponential on CPU (#69967)
### Description
Add mkl implementation for exponential on CPU to improve the performance of exponential.

### Testing
data type: float32
single socket (28cores):
```
before: torch.Size([10, 128, 10, 124])  0.065 s
        torch.Size([10, 128, 20, 124])  0.130 s

after:  torch.Size([10, 128, 10, 124])  5.9e-05 s
        torch.Size([10, 128, 20, 124])  0.000113 s
```
single core:
```
before: torch.Size([10, 128, 10, 124])  0.065 s
        torch.Size([10, 128, 20, 124])  0.130 s

after:  torch.Size([10, 128, 10, 124])  0.00117 s
        torch.Size([10, 128, 20, 124])  0.002347 s
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/69967
Approved by: https://github.com/frank-wei, https://github.com/jgong5
2022-12-13 09:51:24 +00:00
Till Hoffmann
b485781440 Add a transform for positive-definite matrices. (#76777)
The `PositiveDefiniteTransform` is required to transform from an unconstrained space to positive definite matrices, e.g. to support testing the Wishart mode in #76690. It is a simple extension of the `LowerCholeskyTransform`.

I've also added a small test that ensures the generated data belong to the domain of the associated transform. Previously, the data generated for the inverse transform of the `LowerCholeskyTransform` wasn't part of the domain, and the test only passed because the comparison uses `equal_nan=True`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76777
Approved by: https://github.com/lezcano, https://github.com/fritzo, https://github.com/soumith
2022-12-08 09:18:44 +00:00
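A hand-rolled 2x2 sketch of the construction (assumption: the transform composes a lower-Cholesky parameterization with `L @ L.T`; this is not the torch code):

```python
import math

def positive_definite_2x2(u1: float, u2: float, u3: float):
    """Map three unconstrained reals to a 2x2 positive-definite matrix:
    exponentiate the diagonal entries to get a valid lower-Cholesky factor L,
    then return A = L @ L.T, which is positive definite by construction."""
    l11, l21, l22 = math.exp(u1), u2, math.exp(u3)
    return [
        [l11 * l11, l21 * l11],
        [l21 * l11, l21 * l21 + l22 * l22],
    ]

A = positive_definite_2x2(-0.3, 1.7, 0.5)
# Sylvester's criterion for the 2x2 case: positive leading principal minors.
assert A[0][0] > 0.0
assert A[0][0] * A[1][1] - A[0][1] * A[1][0] > 0.0
assert A[0][1] == A[1][0]   # symmetric
```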
Edward Z. Yang
a43e09c064 Implement gamma cdf (#89955)
Authored by tillahoffmann originally at https://github.com/pytorch/pytorch/pull/72518

Implements the cumulative distribution function for the gamma distribution. The tests needed a small adjustment to pass because gradients cannot be evaluated with respect to the first argument of the incomplete gamma function (and they're not needed for the test).

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89955
Approved by: https://github.com/wconstab, https://github.com/malfet
2022-12-01 00:12:53 +00:00
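One quick sanity check for any gamma CDF implementation (illustrative; the torch version presumably goes through the regularized incomplete gamma function): at shape a = 1 the gamma distribution reduces to the exponential, whose CDF has the closed form 1 - exp(-rate * x):

```python
import math

def gamma_cdf_shape1(x: float, rate: float = 1.0) -> float:
    """Gamma CDF at shape a = 1 (the exponential special case): 1 - exp(-rate * x)."""
    return -math.expm1(-rate * x)

# Cross-check against trapezoidal integration of the shape-1 density exp(-x).
n, hi = 100_000, 3.0
h = hi / n
pdf = [math.exp(-i * h) for i in range(n + 1)]
numeric = h * (sum(pdf) - 0.5 * (pdf[0] + pdf[-1]))

assert abs(numeric - gamma_cdf_shape1(hi)) < 1e-6
```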
Kazuaki Ishizaki
1cd6ebe095 Fix typos in messages under torch (#89049)
This PR fixes typos in messages in `.py` files under the torch directory.
In `torch/onnx/symbolic_opset16.py` only, it also fixes a typo in a comment to make the operator name correct.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89049
Approved by: https://github.com/lezcano
2022-11-17 04:18:14 +00:00
Johannes Pitz
8ebbd5a89a Easier to understand event_dim computation (#81396)
Fixes #81254
Only easier to understand, not a real fix.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81396
Approved by: https://github.com/fritzo, https://github.com/kit1980
2022-11-16 04:38:32 +00:00
Kazuaki Ishizaki
4ea2310f1e Fix typos used in documents under torch directory (#88483)
This PR fixes typos in comments of Python files that are found via the search box at https://pytorch.org/docs/master/search.html.
This is a follow-up of #88300.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88483
Approved by: https://github.com/kit1980
2022-11-08 01:33:36 +00:00
Ethan Pronovost
585d71513d Add type annotations to distribution.py (#87577)
As title.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87577
Approved by: https://github.com/kit1980
2022-10-26 18:50:48 +00:00
jimku9
32152ce328 Add original sources/references to Wishart.py in distributions (#86543)
@fritzo As discussed, added original sources/references to Wishart.py in distributions and corrected typos in the error messages.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86543
Approved by: https://github.com/fritzo
2022-10-11 21:21:53 +00:00
anjali411
e2a4dfa468 Add correct __all__ for torch.distributed and torch.cuda submodules (#85702)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85702
Approved by: https://github.com/ezyang, https://github.com/albanD, https://github.com/rohan-varma
2022-10-10 19:15:24 +00:00
anjali411
cf2f552cd8 Add __all__ to torch.{fx, distributed, backends} submodules (#85079)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85079
Approved by: https://github.com/rohan-varma
2022-09-20 12:51:08 +00:00
Horace He
8843f5b986 remove data-dependent shapes from some distributions (#84322)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84322
Approved by: https://github.com/voznesenskym
2022-08-31 09:37:55 +00:00
joncrall
b136f3f310 More doctest refinements. (#83317)
Follow up to #82797

Now that the doctests themselves are in a better state, we should be able to enable xdoctest on the CI so they stay that way.

@ezyang @vadimkantorov
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83317
Approved by: https://github.com/ezyang
2022-08-22 20:07:26 +00:00
joncrall
4618371da5 Integrate xdoctest - Rebased (#82797)
This is a new version of #15648 based on the latest master branch.

Unlike the previous PR where I fixed a lot of the doctests in addition to integrating xdoctest, I'm going to reduce the scope here. I'm simply going to integrate xdoctest, and then I'm going to mark all of the failing tests as "SKIP". This will let xdoctest run on the dashboards, provide some value, and still let the dashboards pass. I'll leave fixing the doctests themselves to another PR.

In my initial commit, I do the bare minimum to get something running with failing dashboards. The few tests that I marked as skip are causing segfaults. Running xdoctest results in 293 failed, 201 passed tests. The next commits will be to disable those tests. (unfortunately I don't have a tool that will insert the `#xdoctest: +SKIP` directive over every failing test, so I'm going to do this mostly manually.)

Fixes https://github.com/pytorch/pytorch/issues/71105

@ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82797
Approved by: https://github.com/ezyang
2022-08-12 02:08:01 +00:00
ProGamerGov
71d50f4f89 Change docstring type callable to Callable for consistency (#82487)
### Description

Across PyTorch's docstrings, both `callable` and `Callable` are used for variable types. `Callable` should be capitalized when we are referring to the `Callable` type, and not the Python `callable()` function.

### Testing

There shouldn't be any testing required.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82487
Approved by: https://github.com/albanD
2022-08-01 17:26:09 +00:00
Feynman Liang
40feeea500 Fix typo in dirichlet.py example (#82062)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82062
Approved by: https://github.com/kit1980
2022-07-23 22:30:12 +00:00
Till Hoffmann
bf6481553a Ensure Transform is pickleable. (#81707)
`Transform` is not currently pickleable if the inverse transform cache `_inv` is not `None` because `_inv` is a `weakref` which cannot be serialized by `pickle`.

The following succeeds.

```python
>>> import torch as th
>>> import pickle

>>> dist = th.distributions.TransformedDistribution(
...     th.distributions.Normal(0, 1),
...     [th.distributions.AffineTransform(2, 3)]
... )
>>> th.save(dist, "some-file.pt")
```

But the transformed distribution can no longer be pickled after evaluating `log_prob` (which implicitly creates `_inv`).

```python
>>> dist.log_prob(th.linspace(0, 1, 10))
>>> th.save(dist, "some-file.pt")
TypeError: cannot pickle 'weakref' object
```

This PR fixes the issue by setting `_inv` to `None` in `__getstate__`. cc @fritzo, @neerajprad
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81707
Approved by: https://github.com/fritzo
2022-07-22 06:33:53 +00:00
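A minimal, self-contained sketch of the fix (hypothetical classes; the real `Transform` lives in `torch.distributions.transforms`): dropping the weakref cache in `__getstate__` makes the object pickleable again, and the cache is simply rebuilt lazily after unpickling:

```python
import pickle
import weakref

class SketchTransform:
    def __init__(self):
        self._inv = None   # weakref cache for the inverse, as in torch

    @property
    def inv(self):
        inv = self._inv() if self._inv is not None else None
        if inv is None:
            inv = _Inverse(self)
            self._inv = weakref.ref(inv)   # weakrefs cannot be pickled
        return inv

    def __getstate__(self):
        state = self.__dict__.copy()
        state["_inv"] = None   # drop the cache; it is recreated on demand
        return state

class _Inverse:
    def __init__(self, base):
        self.base = base

t = SketchTransform()
keep_alive = t.inv                         # populates the weakref cache
restored = pickle.loads(pickle.dumps(t))   # TypeError without __getstate__
assert restored._inv is None
```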
MohammadReza Ebrahimi
29a9928767 torch.distribution examples rendering issue (#81611)
# Issue

"Example" section in the documentation is not rendering correctly for [`torch.distributions.transforms.CatTransform`](https://pytorch.org/docs/stable/distributions.html#torch.distributions.transforms.CatTransform)
and
[`torch.distributions.transforms.StackTransform`](https://pytorch.org/docs/stable/distributions.html#torch.distributions.transforms.StackTransform)

# Fix
Simply add an empty line after the `Example::` keyword to fix the issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81611
Approved by: https://github.com/kit1980
2022-07-19 01:06:37 +00:00
Nicola De Cao
27fc9fcd13 More stable computation of KL between two Bernoulli distributions (#79944)
Fixes #20164

@neerajprad here the new PR with the updated master

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79944
Approved by: https://github.com/neerajprad
2022-06-27 21:31:45 +00:00
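The endpoint behavior at stake can be sketched in pure Python (this is only an illustration of the 0 * log(0) convention, not the torch implementation):

```python
import math

def kl_bernoulli(p: float, q: float) -> float:
    """KL(Bern(p) || Bern(q)) using the 0 * log(0) = 0 convention, so the
    endpoints p = 0 and p = 1 no longer produce nan."""
    def term(a: float, b: float) -> float:
        if a == 0.0:
            return 0.0   # lim a->0 of a * log(a / b) is 0
        return a * (math.log(a) - math.log(b))
    return term(p, q) + term(1.0 - p, 1.0 - q)

assert kl_bernoulli(0.5, 0.5) == 0.0
assert abs(kl_bernoulli(0.0, 0.5) - math.log(2.0)) < 1e-12
assert not math.isnan(kl_bernoulli(1.0, 0.5))
```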