Applies the remaining flake8-comprehensions fixes and checks. This change replaces all remaining unnecessary generator expressions with list/dict/set comprehensions, which are more succinct, more performant, and better supported by our torch.jit compiler. It also removes useless generators such as `set(a for a in b)`, resolving them into a plain `set(b)` call.
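For illustration, here are a couple of representative rewrites of the kind this change applies (hypothetical snippets, not taken from the actual diff):
```python
pairs = [("a", 1), ("b", 2)]
b = [1, 2, 2, 3]

# Before: unnecessary generator expressions passed to list()/dict()/set()
squares = list(x * x for x in range(10))
lookup = dict((k, v) for k, v in pairs)
unique = set(a for a in b)

# After: comprehensions, or the bare constructor when no transform is needed
squares = [x * x for x in range(10)]
lookup = {k: v for k, v in pairs}
unique = set(b)
```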
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94676
Approved by: https://github.com/ezyang
Based on @ezyang's suggestion, the mode stack now has "one true mode", which is the _only_ mode that can ever be active at the C++ level. That mode's torch dispatch just takes the top mode in the stack, re-enables itself (if we aren't at the end of the mode stack), and runs that top mode's torch_{dispatch|function}.
This maintains the invariant that, in the middle of a mode's torch dispatch, the mode itself is not active. It changes the function the user has to call to see what the current mode is (it no longer queries the C++; it's Python only), but it also lets the user easily see the entire mode stack.
Removes `enable_torch_dispatch_mode` and `.restore()`, since neither makes sense in this new setup.
### Background
Why do we want this? Well, a pretty common pattern that was coming up was that users had to do something like
```python
## PRE-PR UX
def f(mode):
    with mode.restore():  # user needs to understand this restore thing?
        ...

with Mode() as m:
    pass
f(m)
```
Many users were getting errors from forgetting to call `.restore` or from forgetting to add the (tbh weird) "mode instantiation" step where they use the mode as a context manager with an empty body. Really, they wanted to treat modes like context managers and just write
```python
## FROM FEEDBACK, USER DESIRED CODE. POSSIBLE POST-PR
def f(mode):
    with mode:
        ...

f(Mode())
```
### Technical Details
With the old mode stack, we basically had a linked list, so a mode could only be used once and had a fixed parent. In this new design, the mode stack is just a Python list that we push to and pop from. There's only one mode that's ever active at the C++ level, and it runs the next mode in the Python list. The modes don't carry any state anymore.
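A rough, framework-agnostic sketch of that design (all names below are made up for illustration; the real machinery lives in torch's mode handling):
```python
# Hypothetical sketch: a plain Python list of modes plus one "true" dispatcher.
_mode_stack = []  # modes the user has entered, innermost last


class _OneTrueMode:
    """Stand-in for the only thing ever active at the C++ level in this design."""

    def dispatch(self, func, args, kwargs):
        if not _mode_stack:
            return func(*args, **kwargs)
        mode = _mode_stack.pop()          # the mode is not active inside its own handler
        try:
            return mode.handle(func, args, kwargs)
        finally:
            _mode_stack.append(mode)      # re-activate it afterwards


_ONE_TRUE_MODE = _OneTrueMode()


class Mode:
    """User-facing mode: just a context manager that pushes/pops itself."""

    def handle(self, func, args, kwargs):
        # Default behaviour: fall through to the rest of the stack.
        return _ONE_TRUE_MODE.dispatch(func, args, kwargs)

    def __enter__(self):
        _mode_stack.append(self)
        return self

    def __exit__(self, *exc):
        _mode_stack.pop()


# Tiny demo: the mode falls through and the underlying call still runs.
with Mode():
    _ONE_TRUE_MODE.dispatch(print, ("hello",), {})
```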
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84774
Approved by: https://github.com/ezyang, https://github.com/zou3519
Composite compliance is supposed to check whether a composite function
calls .item()
([ref](39db8b3823/torch/testing/_internal/composite_compliance.py (L135-L138))).
This PR fixes that check and adds some more documentation.
Why do we need this check? The original motivation is that Tensor subclasses
may not support .item() calls (e.g. vmap and ProxyTensor).
There is no way for these subclasses to meaningfully override the .item() calls
in composite functions that exist inside the PyTorch framework without raising
an error*, so we should aim to rewrite composite operations to not call .item().
*We're open to other solutions; this is just the one we decided on when we
wrote composite compliance testing, and these tests help us keep track of the
failing functionality.
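As a hedged illustration of the kind of rewrite meant here (the functions below are made up), data-dependent control flow through .item() can often be replaced by staying in tensor operations, which subclasses can intercept normally:
```python
import torch


# Problematic composite-style code: .item() forces a data-dependent Python
# branch that a Tensor subclass (e.g. vmap's batched tensors) cannot
# meaningfully intercept.
def clamp_if_large(x):
    if x.abs().max().item() > 10.0:
        return x.clamp(-10.0, 10.0)
    return x


# Equivalent rewrite that stays in tensor-land: clamp is a no-op for values
# already inside the range, so the data-dependent branch is unnecessary.
def clamp_rewritten(x):
    return x.clamp(-10.0, 10.0)


print(clamp_if_large(torch.tensor([1.0, 20.0])))
print(clamp_rewritten(torch.tensor([1.0, 20.0])))
```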
Test Plan:
- wait for tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81060
Approved by: https://github.com/ezyang
Maybe niche, but for one-off debugging purposes, I want a variant of
check_backward_formula that accepts a callable rather than an OpInfo.
This is because, when debugging, I try to create a repro that does not
involve OpInfos, since OpInfos are difficult to deal with (they have
a lot of sample inputs, I may want to test my own sample inputs without
creating a new OpInfo, etc.).
This PR refactors check_backward_formula so that it accepts a Callable
instead of an OpInfo, exposed as check_backward_formula_callable. Example usage:
```python
import torch
from torch.testing._internal.composite_compliance import check_backward_formula_callable

x = torch.tensor([[1., 1.], [1., 0.]], requires_grad=True)
args = (x, 1)
check_backward_formula_callable(torch.prod, args, {})
```
Test Plan:
- run existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81059
Approved by: https://github.com/kshitij12345, https://github.com/ezyang
When testing composite compliance, the conj bit and neg bit are not
propagated to the wrapper tensor. This leads to problems when a
composite operator has two paths depending on whether one of these
bits is set, since the non-conjugated path will always be taken.
For example, `at::real` effectively does
```cpp
view_as_real(tensor.is_conj() ? tensor.conj() : tensor)
```
which will never call `conj()` because the `CompositeCompliantTensor`
never has the conj bit set. The result is that `view_as_real` fails
when `r.elem` does have the conj bit set.
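For reference, the lazy conjugation behaviour the wrapper needs to mirror looks like this in plain PyTorch (minimal sketch):
```python
import torch

x = torch.randn(3, dtype=torch.cfloat)
y = x.conj()                       # lazy: just sets the conj bit, no copy
print(y.is_conj())                 # True
print(y.resolve_conj().is_conj())  # False: materializes the conjugation

# view_as_real needs the conjugation resolved first, which is why
# composite code like at::real has to branch on is_conj().
print(torch.view_as_real(y.resolve_conj()))
```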
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75830
Approved by: https://github.com/zou3519
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74646
The OpInfo-based test, given an operator and sample inputs,
checks all permutations of {inputs, grad_output} being either
{CompositeCompliantTensor, regular Tensor}, running them through a
forward pass and a backward pass.
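A heavily simplified sketch of the shape of one such case (the `wrap` helper below is a made-up stand-in for wrapping in CompositeCompliantTensor; here it is the identity so the snippet runs on its own):
```python
import itertools
import torch


def wrap(t):
    # Placeholder for "wrap t in CompositeCompliantTensor".
    return t


def check_forward_backward(op, args, wrap_inputs, wrap_grad_output):
    # Optionally wrap the tensor inputs, run a forward pass, then feed an
    # (optionally wrapped) grad_output into a backward pass.
    args = tuple(
        (wrap(a) if wrap_inputs else a).detach().clone().requires_grad_(True)
        if isinstance(a, torch.Tensor) else a
        for a in args
    )
    out = op(*args)
    grad_output = torch.ones_like(out)
    if wrap_grad_output:
        grad_output = wrap(grad_output)
    out.backward(grad_output)


# All {wrapped, regular} combinations of inputs and grad_output.
for wrap_inputs, wrap_grad_output in itertools.product([False, True], repeat=2):
    check_forward_backward(torch.prod, (torch.randn(3),), wrap_inputs, wrap_grad_output)
```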
Test Plan: - wait for tests
Reviewed By: albanD
Differential Revision: D35186860
Pulled By: zou3519
fbshipit-source-id: 8b2577dd6106c05db2ab583bbefd10545fdd8adf
(cherry picked from commit 3f5c3793715af9a8d4db06690c5faa7256a82645)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74645
This PR adds tests for when only some inputs are Tensor Subclasses.
Why is this important to test?
==============================
Consider the following hypothetical out-of-place operation:
```python
def my_add(x, y):
    result = x.clone()
    result.add_(y)
    return result
```
You may expect this to work the same as torch.add. But if x is not a Tensor
subclass and y is a Tensor subclass, then this returns a regular
Tensor, NOT a Tensor subclass!
This is exactly the kind of in-place pattern that causes `vmap` to
fail and will be problematic for certain Tensor subclasses in the future,
so we're adding tests to make sure composite PyTorch operations don't do
this.
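As a minimal demonstration (the subclass here is a trivial placeholder, made up just to show which type comes back):
```python
import torch


class MySubclass(torch.Tensor):
    pass


def my_add(x, y):
    result = x.clone()
    result.add_(y)
    return result


x = torch.ones(2)                          # regular Tensor
y = torch.ones(2).as_subclass(MySubclass)  # Tensor subclass

print(type(torch.add(x, y)))  # MySubclass: the subclass is preserved
print(type(my_add(x, y)))     # torch.Tensor: clone-then-add_ drops the subclass
```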
What exactly does this PR do?
=============================
Composite compliance now takes a sample input and produces a test case
where some of the sample inputs are Tensor subclasses. It then sends
this through the original operation, once with Python Mode and once
without.
(Why once with Python Mode? Because we want to use it to detect the
pattern of "create a Tensor and call resize_ on it".)
Finally, it repeats this process for all possibilities where the inputs
are Tensor subclasses. For example, if the sample input is (x, y), then
we test all four of the following cases:
- Subclass(x), y
- x, Subclass(y)
- Subclass(x), Subclass(y)
- x, y
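A small sketch of how those cases can be enumerated in general (just the bookkeeping, not the actual wrapping):
```python
import itertools

def subclass_cases(num_args):
    # One boolean per argument: True means "wrap this argument in the
    # Tensor subclass", False means "leave it as a regular Tensor".
    return list(itertools.product([False, True], repeat=num_args))

# For a two-argument sample input (x, y) this yields the four cases above.
print(subclass_cases(2))
```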
Test Plan
=========
- run tests
Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D35186862
Pulled By: zou3519
fbshipit-source-id: 102477507b56583463668db7523a6586d92b357d
(cherry picked from commit bfcb087244b0598abb270f7c26d472482f00b5e2)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74644
This is in preparation for me adding additional tests for:
1. composite compliance of autograd formulas
2. composite compliance of forward-mode AD formulas
This PR also changes these tests to run on both CPU and CUDA. Previously
they were just run on CPU, but it turns out there's a lot of branching
on the device in composite operations in PyTorch today :/
Test Plan: - wait for tests
Reviewed By: albanD
Differential Revision: D35186861
Pulled By: zou3519
fbshipit-source-id: d974592a7547f71ef26ff0740bf453f7d335d55a
(cherry picked from commit 773b43394c2406502a6e386a30eb003a73861f13)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65819
Related to #61669.
Functions registered as CompositeImplicitAutograd MUST work for most, if
not all, backends. This includes Tensor subclasses.
To achieve this, we (PyTorch) impose a set of constraints on how a
CompositeImplicitAutograd function can be written.
Concretely, this PR adds tests for all OpInfos that check for
compliance. The things that get tested in this PR apply to composite
ops and verify that:
- the op does not change the metadata of a Tensor without performing
dispatches
- the op does not call set_ or resize_
- the op does not directly access the data ptr
The mechanism for the test is to create a new __torch_dispatch__
tensor subclass, CompositeCompliantTensor. For each operator, we wrap all inputs
in CompositeCompliantTensor, turn on Python mode for it,
and send them through the operator.
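A heavily stripped-down sketch of that wrapper pattern (the real CompositeCompliantTensor does much more, e.g. metadata checks and storage aliasing for views; the class and the crude op check below are illustrative only):
```python
import torch
from torch.utils._pytree import tree_map


class TinyComplianceTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elem):
        # Wrapper tensor: mirrors elem's metadata but has its own identity.
        r = torch.Tensor._make_wrapper_subclass(
            cls, elem.size(), strides=elem.stride(), dtype=elem.dtype,
            device=elem.device, requires_grad=elem.requires_grad)
        r.elem = elem
        return r

    # Route everything through __torch_dispatch__ rather than __torch_function__.
    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Composite ops are not allowed to call set_ or resize_.
        if func.overloadpacket in (torch.ops.aten.set_, torch.ops.aten.resize_):
            raise RuntimeError(f"composite op is not allowed to call {func}")

        unwrap = lambda t: t.elem if isinstance(t, TinyComplianceTensor) else t
        wrap = lambda t: TinyComplianceTensor(t) if isinstance(t, torch.Tensor) else t

        # Run the real op on the unwrapped tensors, then re-wrap the outputs.
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
        return tree_map(wrap, out)


# Usage sketch: wrap the inputs and push an op through.
x = TinyComplianceTensor(torch.randn(3))
y = (x * 2).sum()
print(type(y), y.elem)
```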
Non-CompositeImplicitAutograd operators will pass the test because they
perform a dispatch to backend code. Here's how CompositeCompliantTensor
catches problems:
- If it sees set_ or resize_ getting called, it will directly error
out
- After each operation, CompositeCompliantTensor checks to make sure
that its metadata is consistent with that of the thing it is wrapping.
If the CompositeImplicitAutograd op modifies the metadata directly
(through e.g. the TensorImpl API) then the metadata will go out of sync.
- If data_ptr gets called, that returns a nice error (because the
storage is meta).
CompositeCompliantTensor is written in an interesting way. First off,
if a view operation occurs (e.g. `B = A.view_op(...)`), then B.storage()
must alias A.storage() where B.storage() is CompositeCompliantTensor's
storage, NOT the storage of the tensor it is wrapping. This is an
invariant in autograd, see #62182 for details. To handle
this we replay the view on A's storage and set it as B's storage.
Secondly, there are cases where the metadata is allowed to go out of
sync. I believe this is only possible with in-place view functions, like
transpose_, t_, squeeze_, unsqueeze_. Those are special cased.
Finally, I added a new section to aten/src/ATen/native/README.md about
what it means to be CompositeImplicitAutograd compliant.
Test Plan: - run tests
Reviewed By: ezyang, bdhirsh
Differential Revision: D31268369
Pulled By: zou3519
fbshipit-source-id: 31634b1cbe1778ab30196013cfc376ef9bd2e8b1