Commit Graph

50 Commits

Author SHA1 Message Date
Xuehai Pan
a10b765bf1 [pytree] add APIs to determine a class is a namedtuple or PyStructSequence (#113257)
Changes in this PR:

1. Add `is_structseq` and `is_structseq_class` functions to determine a object or a class is PyStructSequence.
2. Add a generic class `structseq` which can be used as the registration key for PyStructSequence types like `namedtuple` for Named Tuple types.
3. Change `is_namedtuple` to accept subclasses of namedtuple to be namedtuple. Before this PR, only namedtuple class directly created by `collections.namedtuple` or `typing.NamedTuple` were namedtuple classes while their subclasses were not. This PR makes `is_namedtuple` return true for subclasses of namedtuple class.

Resolves #75982. New tests are included in this PR.

- #75982

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113257
Approved by: https://github.com/zou3519
2025-04-01 10:40:43 +00:00
PyTorch MergeBot
f9b4856989 Revert "[pytree] add APIs to determine a class is a namedtuple or PyStructSequence (#113257)"
This reverts commit c95a6b416b.

Reverted https://github.com/pytorch/pytorch/pull/113257 on behalf of https://github.com/ZainRizvi due to Sorry but this is breaking internally. @zou3519 can you please help land this internally? See the sigmoid tests in D71198793 for details. To validate the fixes internally, you can follow the instructions here: https://fburl.com/fixing-ghfirst-reverts ([comment](https://github.com/pytorch/pytorch/pull/113257#issuecomment-2725982539))
2025-03-14 23:13:34 +00:00
Xuehai Pan
c95a6b416b [pytree] add APIs to determine a class is a namedtuple or PyStructSequence (#113257)
Changes in this PR:

1. Add `is_structseq` and `is_structseq_class` functions to determine a object or a class is PyStructSequence.
2. Add a generic class `structseq` which can be used as the registration key for PyStructSequence types like `namedtuple` for Named Tuple types.
3. Change `is_namedtuple` to accept subclasses of namedtuple to be namedtuple. Before this PR, only namedtuple class directly created by `collections.namedtuple` or `typing.NamedTuple` were namedtuple classes while their subclasses were not. This PR makes `is_namedtuple` return true for subclasses of namedtuple class.

Resolves #75982. New tests are included in this PR.

- #75982

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113257
Approved by: https://github.com/zou3519
2025-03-14 08:50:30 +00:00
PyTorch MergeBot
ebd087e4b5 Revert "[pytree] add APIs to determine a class is a namedtuple or PyStructSequence (#113257)"
This reverts commit f08146b67b.

Reverted https://github.com/pytorch/pytorch/pull/113257 on behalf of https://github.com/jovianjaison due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/113257#issuecomment-2711299830))
2025-03-10 17:19:21 +00:00
Xuehai Pan
f08146b67b [pytree] add APIs to determine a class is a namedtuple or PyStructSequence (#113257)
Changes in this PR:

1. Add `is_structseq` and `is_structseq_class` functions to determine a object or a class is PyStructSequence.
2. Add a generic class `structseq` which can be used as the registration key for PyStructSequence types like `namedtuple` for Named Tuple types.
3. Change `is_namedtuple` to accept subclasses of namedtuple to be namedtuple. Before this PR, only namedtuple class directly created by `collections.namedtuple` or `typing.NamedTuple` were namedtuple classes while their subclasses were not. This PR makes `is_namedtuple` return true for subclasses of namedtuple class.

Resolves #75982. New tests are included in this PR.

- #75982

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113257
Approved by: https://github.com/zou3519
2025-03-06 18:59:02 +00:00
Tom Ritchford
c0582fd0f8 Remove unused Python variables in torch/[b-z]* (#136963)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136963
Approved by: https://github.com/ezyang
2024-10-19 16:45:22 +00:00
Benjamin Glass
f984b88718 Ensure noncontiguous tensor creation tests offsetting (#136396)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136396
Approved by: https://github.com/amjames, https://github.com/eellison
ghstack dependencies: #136055
2024-10-02 00:40:43 +00:00
albanD
6791b0c09e Change default torch_function behavior to be disabled when torch_dispatch is defined (take 2) (#120632)
This does not introduce a new test but is tested by checking that all the classes we already have still behave as before now that they don't explicitly disable torch_function.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120632
Approved by: https://github.com/ezyang
2024-03-09 01:08:37 +00:00
Edward Z. Yang
9bce208dfb Replace follow_imports = silent with normal (#118414)
This is a lot of files changed! Don't panic! Here's how it works:

* Previously, we set `follow_imports = silent` for our mypy.ini configuration. Per https://mypy.readthedocs.io/en/stable/running_mypy.html#follow-imports, what this does is whenever we have an import to a module which is not listed as a file to be typechecked in mypy, we typecheck it as normal but suppress all errors that occurred in that file.
* When mypy is run inside lintrunner, the list of files is precisely the files covered by the glob in lintrunner.toml, but with files in excludes excluded.
* The top-level directive `# mypy: ignore-errors` instructs mypy to typecheck the file as normal, but ignore all errors.
* Therefore, it should be equivalent to set `follow_imports = normal`, if we put `# mypy: ignore-errors` on all files that were previously excluded from the file list.
* Having done this, we can remove the exclude list from .lintrunner.toml, since excluding a file from typechecking is baked into the files themselves.
* torch/_dynamo and torch/_inductor were previously in the exclude list, because they were covered by MYPYINDUCTOR. It is not OK to mark these as `# mypy: ignore-errors` as this will impede typechecking on the alternate configuration. So they are temporarily being checked twice, but I am suppressing the errors in these files as the configurations are not quite the same. I plan to unify the configurations so this is only a temporary state.
* There were some straggler type errors after these changes somehow, so I fixed them as needed. There weren't that many.

In the future, to start type checking a file, just remove the ignore-errors directive from the top of the file.

The codemod was done with this script authored by GPT-4:

```
import glob

exclude_patterns = [
    ...
]

for pattern in exclude_patterns:
    for filepath in glob.glob(pattern, recursive=True):
        if filepath.endswith('.py'):
            with open(filepath, 'r+') as f:
                content = f.read()
                f.seek(0, 0)
                f.write('# mypy: ignore-errors\n\n' + content)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118414
Approved by: https://github.com/thiagocrepaldi, https://github.com/albanD
2024-01-27 02:44:11 +00:00
Xuehai Pan
55064a4ef9 [BE] add parentheses to kwargs unpacking func(*args, **(kwargs or {})) (#115026)
This PR adds parentheses to kwargs unpacking `func(*args, **(kwargs or {}))` for better code readability.

With/without the parentheses are semantic equivalent because they produce the same bytecode.

```console
$ echo "func(*args, **kwargs or {})" | python3 -m dis -
  0           0 RESUME                   0

  1           2 PUSH_NULL
              4 LOAD_NAME                0 (func)
              6 LOAD_NAME                1 (args)
              8 BUILD_MAP                0
             10 LOAD_NAME                2 (kwargs)
             12 JUMP_IF_TRUE_OR_POP      1 (to 16)
             14 BUILD_MAP                0
        >>   16 DICT_MERGE               1
             18 CALL_FUNCTION_EX         1
             20 POP_TOP
             22 LOAD_CONST               0 (None)
             24 RETURN_VALUE

$ echo "func(*args, **(kwargs or {}))" | python3 -m dis -
  0           0 RESUME                   0

  1           2 PUSH_NULL
              4 LOAD_NAME                0 (func)
              6 LOAD_NAME                1 (args)
              8 BUILD_MAP                0
             10 LOAD_NAME                2 (kwargs)
             12 JUMP_IF_TRUE_OR_POP      1 (to 16)
             14 BUILD_MAP                0
        >>   16 DICT_MERGE               1
             18 CALL_FUNCTION_EX         1
             20 POP_TOP
             22 LOAD_CONST               0 (None)
             24 RETURN_VALUE
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115026
Approved by: https://github.com/Skylion007
2023-12-03 20:03:26 +00:00
drisspg
c46fc46dba expose mem-eff to autograd (#110495)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110495
Approved by: https://github.com/jbschlosser
2023-11-13 17:47:40 +00:00
Peter Bell
04024926f4 Use pytree.tree_map_ everywhere (#112417)
Wherever we discard the output of `tree_map` it's better to call `tree_map_`
which doesn't unflatten the mapped results and so is a lot cheaper.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112417
Approved by: https://github.com/lezcano
ghstack dependencies: #112391, #112392, #112393, #112394
2023-10-31 15:57:06 +00:00
Peter Bell
66c32d099a Use pytree.arg_tree_leaves everywhere (#112394)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112394
Approved by: https://github.com/lezcano
ghstack dependencies: #112391, #112392, #112393
2023-10-31 15:57:06 +00:00
Peter Bell
bbd5b935e4 Use pytree.tree_leaves everywhere (#112324)
This changes all the instances I could find of `tree_flatten(...)[0]` or
`x, _ = tree_flatten` to use `tree_leaves`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112324
Approved by: https://github.com/lezcano
ghstack dependencies: #112327, #112323
2023-10-30 03:39:04 +00:00
Aaron Gokaslan
e2a3817dfd [BE] Enable C419 rule for any all shortcircuiting (#99890)
Apparently https://github.com/pytorch/pytorch/pull/78142 made torch.JIT allow for simple generator expressions which allows us to enable rules that replace unnecessary list comprehensions with generators in any/all. This was originally part of #99280 but I split it off into this PR so that it can be easily reverted should anything break.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99890
Approved by: https://github.com/justinchuby, https://github.com/kit1980, https://github.com/malfet
2023-04-25 15:02:13 +00:00
Aaron Gokaslan
67d9790985 [BE] Apply almost all remaining flake8-comprehension checks (#94676)
Applies the remaining flake8-comprehension fixes and checks. This changes replace all remaining unnecessary generator expressions with list/dict/set comprehensions which are more succinct, performant, and better supported by our torch.jit compiler. It also removes useless generators such as 'set(a for a in b)`, resolving it into just the set call.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94676
Approved by: https://github.com/ezyang
2023-02-12 01:01:25 +00:00
Aaron Gokaslan
1e2d82b8e4 [BE] Merge isinstance calls together (#94419)
Simplify and speeds up isinstance calls by checking for multiple types at the same time.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94419
Approved by: https://github.com/ezyang
2023-02-09 00:47:26 +00:00
Kazuaki Ishizaki
1cd6ebe095 Fix typos in messages under torch (#89049)
This PR fixes typos of messages in `.py` files under torch directory.
Only in `torch/onnx/symbolic_opset16.py`, fix a typo in comment to make the operator name correct.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89049
Approved by: https://github.com/lezcano
2022-11-17 04:18:14 +00:00
samdow
18d8c548f4 [Modes] remove enable and rewrite mode stack (squashed) (#84774)
Based on @ezyang's suggestion, mode stack now has "one true mode" which is the _only_ mode that can ever be active at the C++ level. That mode's torch dispatch is just to take the top mode in the stack, reenable itself (if we aren't at the end of the mode stack), and run the top mode's torch_{dispatch|function}

This maintains that in the middle of a mode's torch dispatch, the mode itself will not be active. It changes the function the user has to call to see what the current mode is (no longer queries the C++, it's python only) but allows the user to also see the entire mode stack easily

Removes `enable_torch_dispatch_mode` and `.restore()` since neither makes sense in this new setup

### Background
Why do we want this? Well, a pretty common pattern that was coming up was that users had to do something like

```python
## PRE-PR UX
def f(mode):
  with mode.restore():  # user needs to understand this restore thing?
    ...

with Mode() as m:
  pass
f(m)
```

Many users were getting error from forgetting to call `.restore` or from forgetting to add the (tbh weird) "mode instantiation"  step where they use the mode as a context manager with an empty body. Really, they wanted to treat modes like context managers and just write
```python
## FROM FEEDBACK, USER DESIRED CODE. POSSIBLE POST-PR
def f(mode):
  with mode:
    ...
f(Mode())
```

** Technical Details **
With the old mode stack, we basically had a linked list so the mode itself could only be used once and had a fixed parent. In this new design, the mode stack is just a python list that we're pushing to and popping from. There's only one mode that's ever active at the C++ level and it runs the next mode in the Python list. The modes don't have state on them anymore
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84774
Approved by: https://github.com/ezyang, https://github.com/zou3519
2022-09-27 01:04:35 +00:00
Elias Ellison
bcc544e9d7 Add FakeCrossRef tests for backwards, Fix Layer Norm Backward Decomp (#85417)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85417
Approved by: https://github.com/ezyang
2022-09-26 17:08:14 +00:00
PyTorch MergeBot
d10de31cc8 Revert "Add FakeCrossRef tests for backwards, Fix Layer Norm Backward Decomp (#85417)"
This reverts commit 78afa0cf0c.

Reverted https://github.com/pytorch/pytorch/pull/85417 on behalf of https://github.com/clee2000 due to broke tests on trunk 78afa0cf0c
2022-09-23 17:21:43 +00:00
Elias Ellison
78afa0cf0c Add FakeCrossRef tests for backwards, Fix Layer Norm Backward Decomp (#85417)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85417
Approved by: https://github.com/ezyang
2022-09-23 15:50:03 +00:00
PyTorch MergeBot
5043457a8e Revert "Add FakeCrossRef tests for backwards, Fix Layer Norm Backward Decomp (#85417)"
This reverts commit 9c77083965.

Reverted https://github.com/pytorch/pytorch/pull/85417 on behalf of https://github.com/clee2000 due to broke tests on trunk (and pull somehow) 9c77083965
2022-09-22 15:44:38 +00:00
Elias Ellison
9c77083965 Add FakeCrossRef tests for backwards, Fix Layer Norm Backward Decomp (#85417)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85417
Approved by: https://github.com/ezyang
2022-09-22 13:03:57 +00:00
Edward Z. Yang
55ca297d4e Remove enable_recursive_torch_dispatch (#84945)
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84945
Approved by: https://github.com/soulitzer
2022-09-14 03:25:59 +00:00
Edward Z. Yang
74d0c64708 Don't use reentrant dispatch for composite compliance (#84909)
I believe these were added in to prevent changing behavior when
https://github.com/pytorch/pytorch/pull/75827 landed, but I actually
think they are unnecessary, and they are causing asserts to fire
on the subsequent PR (where I assert that tensors returned by
views MUST NOT already have view metadata associated with them.)

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84909
Approved by: https://github.com/zou3519, https://github.com/soulitzer
2022-09-13 18:41:18 +00:00
soulitzer
ba53efa6e7 Unskip CompositeCompliance tests for ARM (#83089)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83089
Approved by: https://github.com/albanD
2022-08-11 20:01:51 +00:00
Edward Z. Yang
14968d59f2 s/Compiance/Compliance (#82087)
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82087
Approved by: https://github.com/zou3519
2022-07-25 15:41:07 +00:00
Kshiteej K
8b5685da12 [composite compliance] test_operator correctness (#81600)
Time Before PR:
```
= 1111 passed, 45 skipped, 41020 deselected, 17 xfailed, 33 warnings in 52.55s =
```

Time After PR:
```
= 1105 passed, 51 skipped, 41020 deselected, 17 xfailed, 33 warnings in 70.03s (0:01:10) =
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81600
Approved by: https://github.com/zou3519
2022-07-20 21:18:56 +00:00
Kshiteej K
706b420a52 [composite compliance] check output of forward-ad with subclass args against regular tensor (#81464)
Time Before PR
```
= 880 passed, 274 skipped, 38170 deselected, 17 xfailed, 21 warnings in 808.96s (0:13:28) =
```

Time After PR
```
= 875 passed, 274 skipped, 38170 deselected, 22 xfailed, 21 warnings in 880.61s (0:14:40) =
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81464
Approved by: https://github.com/zou3519
2022-07-20 17:38:11 +00:00
kshitij12345
05ce013338 [composite compliance] check output of backward with subclass args against regular tensor (#81400)
Time Before
```
= 919 passed, 12 skipped, 38374 deselected, 36 xfailed, 31 warnings in 699.56s (0:11:39) =
```

Time After
```
= 913 passed, 12 skipped, 38374 deselected, 42 xfailed, 31 warnings in 663.96s (0:11:03) =
```

Will follow-up for operator and forward-ad
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81400
Approved by: https://github.com/zou3519
2022-07-14 05:51:05 +00:00
Richard Zou
9ee312023d [Composite compliance testing] Refactor check_forward_ad_formula to accept Callable (#81239)
Like https://github.com/pytorch/pytorch/pull/81059; this PR addresses
the review comments.

Test Plan:
- run tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81239
Approved by: https://github.com/ezyang
2022-07-11 20:48:18 +00:00
Richard Zou
528ee0fa75 Fix composite compliance testing to check for .item() calls (#81060)
Composite compliance is supposed to check if a composite function
calls .item()
([ref](39db8b3823/torch/testing/_internal/composite_compliance.py (L135-L138))).
This PR fixes that and adds some more documentation.

Why do we need this check? The original motivations are that Tensor subclasses
may not support .item calls (e.g. vmap and ProxyTensor).
There is no way for these subclasses to meaningfully override the .item() calls
in composite functions that exist inside the PyTorch framework without raising
an error* so we should aim to rewrite composite operations to not call .item().

*We're open to other solutions, this is just the one we decided on when we
wrote composite compliance testing and these tests help us keep track of the
failing functionality.

Test Plan:
- wait for tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81060
Approved by: https://github.com/ezyang
2022-07-11 18:37:50 +00:00
Richard Zou
d253cdd8ff [composite compliance testing] Refactor check_backward_formula to accept Callable (#81059)
Maybe niche, but for one-off debugging purposes, I want a variant of
check_backward_formula that accepts a callable rather than an OpInfo.
This is because when debugging, I try to create a repro that does not
involve OpInfos because OpInfos are difficult to deal with (they have
a lot of sample inputs, I may want to test my own sample inputs without
creating a new OpInfo, etc).

This PR refactors check_backward_formula so that it accepts a Callable
instead of an OpInfo. Example usage:

```
import torch
from torch.testing._internal.composite_compliance import check_backward_formula

x = torch.tensor([[1., 1.], [1., 0.]], requires_grad=True)
args = (x, 1)

check_backward_formula_callable(torch.prod, args, {})
```

Test Plan:
- run existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81059
Approved by: https://github.com/kshitij12345, https://github.com/ezyang
2022-07-11 18:37:50 +00:00
Richard Zou
6b0651209e [composite compliance testing] remove tree_flatten hack (#81057)
Previously we had a hack for tree_flatten not supporting
torch.return_types. That was fixed a while ago
(https://github.com/pytorch/pytorch/issues/74624) so we can delete the
hack.

Test Plan:
- wait for tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81057
Approved by: https://github.com/kshitij12345, https://github.com/ezyang
2022-07-11 18:37:49 +00:00
kshitij12345
e51c63da65 [composite compliance] preserve stride correctly for non-contiguous tensor with requires_grad=True (#81035)
Fixes https://github.com/pytorch/pytorch/issues/80858

Also removes the skips on `ravel`, `lstsq` and `lstsq_grad_oriented`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81035
Approved by: https://github.com/zou3519
2022-07-07 14:32:47 +00:00
soulitzer
ed71f88531 Fix composite compliance tensor for inplace views
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79902

Approved by: https://github.com/zou3519
2022-06-21 17:36:30 +00:00
Kshiteej K
04b98df87a [fix] composite compliance: eig, eigh, symeig (#79698)
Ref: https://github.com/pytorch/pytorch/issues/69991
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79698
Approved by: https://github.com/Lezcano, https://github.com/albanD
2022-06-17 14:13:04 +00:00
Elias Ellison
678213ead2 Fake Tensor Part 1
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77969

Approved by: https://github.com/ezyang
2022-05-31 16:20:35 +00:00
samdow
598e7e5f19 [Reland] Change 'python mode' to 'torch dispatch mode'
Changes Python Mode name to Torch Dispatch Mode because there is now a Torch Function Mode, so Torch Dispatch Mode and Torch Function Mode are consistent with each other
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76562
Approved by: https://github.com/zou3519, https://github.com/albanD
2022-05-02 20:06:43 +00:00
PyTorch MergeBot
395a620a4f Revert "Change 'python mode' to 'torch dispatch mode'"
This reverts commit 7203a73986.

Reverted https://github.com/pytorch/pytorch/pull/76562 on behalf of https://github.com/janeyx99
2022-05-02 14:42:11 +00:00
samdow
7203a73986 Change 'python mode' to 'torch dispatch mode'
Changes Python Mode name to Torch Dispatch Mode because there is now a Torch Function Mode, so Torch Dispatch Mode and Torch Function Mode are consistent with each other
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76562
Approved by: https://github.com/zou3519
2022-05-02 13:33:58 +00:00
kshitij12345
c1ced8ff72 [composite compliance] add test for fwd AD
Fixes https://github.com/pytorch/pytorch/issues/74678

Test timings:
```
======================================= 756 passed, 99 skipped, 13864 deselected, 76 xfailed, 16 warnings in 278.35s (0:04:38) =======================================
```

Slowest ops
```
======================================================================== slowest 20 durations ========================================================================
32.16s call     test/test_ops.py::TestCompositeComplianceCUDA::test_forward_ad_nn_functional_instance_norm_cuda_float32
30.51s call     test/test_ops.py::TestCompositeComplianceCPU::test_forward_ad_nn_functional_instance_norm_cpu_float32
9.89s call     test/test_ops.py::TestCompositeComplianceCUDA::test_forward_ad__masked_norm_cuda_float32
8.54s call     test/test_ops.py::TestCompositeComplianceCPU::test_forward_ad__masked_norm_cpu_float32
8.52s call     test/test_ops.py::TestCompositeComplianceCUDA::test_forward_ad_diff_cuda_float32
8.33s call     test/test_ops.py::TestCompositeComplianceCUDA::test_forward_ad_linalg_solve_triangular_cuda_float32
8.08s call     test/test_ops.py::TestCompositeComplianceCPU::test_forward_ad_linalg_solve_triangular_cpu_float32
8.03s call     test/test_ops.py::TestCompositeComplianceCPU::test_forward_ad_diff_cpu_float32
6.52s call     test/test_ops.py::TestCompositeComplianceCUDA::test_forward_ad_cov_cuda_float32
5.77s call     test/test_ops.py::TestCompositeComplianceCPU::test_forward_ad_cov_cpu_float32
4.12s call     test/test_ops.py::TestCompositeComplianceCUDA::test_forward_ad_lu_solve_cuda_float32
3.78s call     test/test_ops.py::TestCompositeComplianceCUDA::test_forward_ad__masked_std_cuda_float32
3.67s call     test/test_ops.py::TestCompositeComplianceCUDA::test_forward_ad_gradient_cuda_float32
3.55s call     test/test_ops.py::TestCompositeComplianceCUDA::test_forward_ad__masked_var_cuda_float32
3.47s call     test/test_ops.py::TestCompositeComplianceCUDA::test_forward_ad_nn_functional_max_pool2d_cuda_float32
3.42s call     test/test_ops.py::TestCompositeComplianceCUDA::test_forward_ad_nn_functional_batch_norm_without_cudnn_cuda_float32
3.40s call     test/test_ops.py::TestCompositeComplianceCPU::test_forward_ad_nn_functional_max_pool2d_cpu_float32
3.30s call     test/test_ops.py::TestCompositeComplianceCPU::test_forward_ad__masked_std_cpu_float32
3.30s call     test/test_ops.py::TestCompositeComplianceCPU::test_forward_ad_gradient_cpu_float32
3.28s call     test/test_ops.py::TestCompositeComplianceCUDA::test_forward_ad_nn_functional_batch_norm_cuda_float32
====================================================================== short test summary info =======================================================================
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75178
Approved by: https://github.com/zou3519
2022-04-25 15:15:48 +00:00
albanD
cd0591dff3 Change default TLS behavior in dispatch to favor is-a style
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75827

Approved by: https://github.com/ezyang
2022-04-20 17:32:29 +00:00
Peter Bell
58fb3f018e Fix conjugate bit discrepancy in composite compliance
When testing composite compliance, the conj bit and neg bit are not
propagated to the wrapper tensor. This leads to problems when a
composite operator has two paths depending on whether one of these
bits are set, since the non-conjugated path will always be taken.

For example, `at::real` effectively does
```cpp
view_as_real(tensor.is_conj() ? tensor.conj() : tensor)
```
which will never call `conj()` because the `CompositeCompliantTensor`
never has has the conj bit set. The result is `view_as_real` fails
when `r.elem` does have the conj bit set.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/75830

Approved by: https://github.com/zou3519
2022-04-19 13:59:28 +00:00
Richard Zou
e832eedd29 Composite Compliance testing for backward formulas (#74646)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74646

The OpInfo-based test, given an operator and sample inputs,
checks all permutations of {inputs, grad_output} being either
{CompositeCompliantTensor, regular Tensor}, running them through a
forward pass and a backward pass.

Test Plan: - wait for tests

Reviewed By: albanD

Differential Revision: D35186860

Pulled By: zou3519

fbshipit-source-id: 8b2577dd6106c05db2ab583bbefd10545fdd8adf
(cherry picked from commit 3f5c3793715af9a8d4db06690c5faa7256a82645)
2022-03-28 22:12:41 +00:00
Richard Zou
80d64b365a Test case where some inputs are Tensor Subclasses in CompositeCompiance (#74645)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74645

This PR adds tests for when only some inputs are Tensor Subclasses.

Why is this important to test?
==============================

Consider the following hypothetical out-of-place operation:
```
def my_add(x, y):
  result = x.clone()
  result.add_(y)
  return result
```

You may expect this to work the same as torch.add. If x is not a Tensor
Subclass, but y is a Tensor subclass, then this returns us a regular
Tensor, NOT a Tensor subclass!

This is exactly the type of in-place operations that causes `vmap` to
fail and will be problematic for certain Tensor Subclasses in the future
so we're adding tests to make sure Composite pytorch operations don't do
this.

What exactly does this PR do?
=============================
Composite compliance now takes a sample input and produces a test case
where some of the sample inputs are Tensor Subclasses. It then sends
this through the original operation, once with Python Mode and one
without.

(Why once with Python Mode? Because we want to use it to detect the
pattern of "create a Tensor and call resize_ on it")

Finally, it repeats this process for all possiblities where the inputs
are Tensor subclasses. For example, if the sample input is (x, y), then
we test all four of the following cases:
- Subclass(x), y
- x, Subclass(y)
- Subclass(x), Subclass(y)
- x, y

Test Plan
=========
- run tests

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D35186862

Pulled By: zou3519

fbshipit-source-id: 102477507b56583463668db7523a6586d92b357d
(cherry picked from commit bfcb087244b0598abb270f7c26d472482f00b5e2)
2022-03-28 22:12:41 +00:00
Richard Zou
c96f321804 Move CompositeCompliance tests to their own TestCase (#74644)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74644

This is in preparation for me adding additional tests for:
1. composite compliance of autograd formulas
2. composite compliance of forward-mode AD formulas

This PR also changes these tests to run on both CPU and CUDA. Previously
they were just run on CPU, but it turns out there's a lot of branching
on the device in composite operations in PyTorch today :/

Test Plan: - wait for tests

Reviewed By: albanD

Differential Revision: D35186861

Pulled By: zou3519

fbshipit-source-id: d974592a7547f71ef26ff0740bf453f7d335d55a
(cherry picked from commit 773b43394c2406502a6e386a30eb003a73861f13)
2022-03-28 22:12:40 +00:00
anjali411
086645ad77 Update __torch_dispatch__ to return op overload instead of the opoverload packet function (#72673)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72673

Test Plan: Imported from OSS

Reviewed By: mruberry

Differential Revision: D34627164

Pulled By: anjali411

fbshipit-source-id: 3cb6406a392d530bf9da36b4d8e0a62b30e6497e
(cherry picked from commit 65b85a0a67df4d0f16ac8964e2b685d478a610fb)
2022-03-07 22:38:42 +00:00
Richard Zou
6fea7499c2 CompositeImplicitAutograd compliance testing (#65819)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65819

Related to #61669.

Functions registered as CompositeImplicitAutograd MUST work for most, if
not all, backends. This includes Tensor subclasses.

To achieve this, we (PyTorch) impose a set of constraints on how a
CompositeImplicitAutograd function can be written.

Concretely, this PR adds tests for all OpInfos that checks for
compliance. The things that get tested in this PR apply to composite
ops and are that:
- the op does not change the metadata of a Tensor without performing
dispatches
- the op does not call set_ or resize_
- the op does not directly access the data ptr

The mechanism for the test is to create a new __torch_dispatch__
object, CompositeCompliantTensor. For each operator, we wrap all inputs
in CompositeCompliantTensor, turn on python mode for it,
and send it through the operator.

Non-CompositeImplicitAutograd operators will pass the test because they
perform a dispatch to backend code. Here's how CompositeCompliantTensor
catches problems:

- If it sees set_ or resize_ getting called, it will directly error
out
- After each operation, CompositeCompliantTensor checks to make sure
that its metadata is consistent with that of the thing it is wrapping.
If the CompositeImplicitAutograd op modifes the metadata directly
(through e.g. the TensorImpl API) then the metadata will go out of sync.
- If data_ptr gets called, that returns a nice error (because the
storage is meta).

CompositeCompliantTensor is written in an interesting way. First off,
if a view operation occurs (e.g. `B = A.view_op(...)`), then B.storage()
must alias A.storage() where B.storage() is CompositeCompliantTensor's
storage, NOT the storage of the tensor it is wrapping. This is an
invariant in autograd, see #62182 for details. To handle
this we replay the view on A's storage and set it as B's storage.

Secondly, there are cases where the metadata is allowed to go out of
sync. I believe this is only possible with in-place view functions, like
transpose_, t_, squeeze_, unsqueeze_. Those are special cased.

Finally, I added a new section to aten/src/ATen/native/README.md about
what it means to be CompositeImplicitAutograd Compliant

Test Plan: - run tests

Reviewed By: ezyang, bdhirsh

Differential Revision: D31268369

Pulled By: zou3519

fbshipit-source-id: 31634b1cbe1778ab30196013cfc376ef9bd2e8b1
2021-11-30 07:35:22 -08:00