Commit Graph

348 Commits

Peter Bell
a2a8c1fda0 [AOTDispatch] Return mutated inputs directly when keeping mutations (#120514)
Fixes #120242

The example from the issue now results in the graph
```python
def forward(self, arg0_1, arg1_1):
    sin = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
    copy_ = torch.ops.aten.copy_.default(arg1_1, sin);  arg1_1 = sin = None
    return (copy_,)
```

and the corresponding inductor kernel eliminates the intermediate buffer
completely

```python
def call(args):
    arg0_1, arg1_1 = args
    args.clear()
    assert_size_stride(arg0_1, (5, ), (1, ))
    assert_size_stride(arg1_1, (5, ), (1, ))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        # Source Nodes: [sin], Original ATen: [aten.sin]
        stream0 = get_raw_stream(0)
        triton_poi_fused_sin_0.run(arg0_1, arg1_1, 5, grid=grid(5), stream=stream0)
        del arg0_1
    return (arg1_1, )
```
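
For context, a minimal snippet in the spirit of the issue (a hypothetical sketch; the exact repro from #120242 is not reproduced here) that produces a graph of this shape:

```python
# Hypothetical sketch: an out=-style mutation whose result is also the return
# value, so the graph can now return the mutated input directly.
import torch

dev = "cuda" if torch.cuda.is_available() else "cpu"

@torch.compile
def f(x, out):
    torch.sin(x, out=out)  # mutates `out`
    return out

f(torch.randn(5, device=dev), torch.empty(5, device=dev))
```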

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120514
Approved by: https://github.com/ezyang, https://github.com/oulgen, https://github.com/lezcano
2024-03-08 16:33:26 +00:00
Peter Bell
eae9751e82 Fix linalg_eigvals invalid use of composite dispatch key (#121142)
`linalg_eigvals_out` calls into a dispatch stub, so it only supports CPU and CUDA
strided tensors, yet it was incorrectly registered as a composite op. `linalg_eigvals`
also shouldn't defer to the out variant inside a `CompositeImplicitAutograd` op,
since not all types support out variants. Instead, I add a new helper
`_linalg_eigvals` that does the same thing as a non-composite operator.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121142
Approved by: https://github.com/lezcano
2024-03-05 21:13:27 +00:00
Tugsbayasgalan Manlaibaatar
c646030cd2 Support higher order op functionalization in predispatch IR (#115314)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115314
Approved by: https://github.com/bdhirsh
2024-03-01 09:13:47 +00:00
Elias Ellison
7ebfe21724 Fix nll loss dynamo failure (#120805)
Fix for https://github.com/pytorch/pytorch/issues/119791 Part of dynamo bug bash
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120805
Approved by: https://github.com/Skylion007, https://github.com/zou3519, https://github.com/malfet
2024-02-29 22:34:49 +00:00
Isuru Fernando
435063aa89 Decomposition for upsample_linear{1d, 3d} (#114774)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114774
Approved by: https://github.com/lezcano, https://github.com/vfdev-5, https://github.com/peterbell10
2024-02-27 11:57:45 +00:00
Isuru Fernando
b7df3bba62 add decomposition for frexp (#119217)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119217
Approved by: https://github.com/peterbell10
ghstack dependencies: #119284, #120027
2024-02-23 21:52:42 +00:00
Joel Schlosser
e7eab2f07e Fix to keep stride in return_and_correct_aliasing() (#117860)
Fixes #117794

The issue tripped the assert here: 86dedebeaf/torch/utils/_python_dispatch.py (L216)

From investigation: I found that functionalization of an in-place op (`mul_` in this test case) results in the strides of `TwoTensor`'s `a` / `b` components being mutated to be contiguous. This is not reflected in the outer tensor, causing the assert to be tripped.

After discussion with Brian, I address this in this PR by disallowing input mutations on non-contiguous tensor subclass inputs for now.
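
A rough, hedged sketch of the failing scenario (assuming the `TwoTensor` test subclass from `torch.testing._internal` and the `aot_eager` backend; the exact repro in #117794 may differ):

```python
import torch
from torch.testing._internal.two_tensor import TwoTensor

@torch.compile(backend="aot_eager", fullgraph=True)
def f(x):
    x.mul_(2)  # in-place op on a subclass input goes through functionalization
    return x + 1

a = torch.ones(3, 2).t()  # non-contiguous inner components
b = torch.ones(3, 2).t()
x = TwoTensor(a, b)

try:
    f(x)
except Exception as e:
    # with this PR, input mutation on a non-contiguous subclass input is rejected up front
    print("disallowed:", e)
```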
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117860
Approved by: https://github.com/bdhirsh
2024-02-21 19:15:27 +00:00
gs-olive
e0f6fa6a7c Windows Dynamo Error Removal CI Check (#115969)
Rebase of #111313 onto `main`, for CI validation

Co-authored-by: Stella Laurenzo <stellaraccident@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115969
Approved by: https://github.com/PaliC, https://github.com/thiagocrepaldi
2024-02-14 21:14:36 +00:00
Brian Hirsh
f9f0c67445 beef up non-overlapping checks for detecting false aliasing of graph inputs (#119826)
This extra check is needed for some more complicated parameter sizes/strides for an internal model
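
A hedged illustration of what "false aliasing" means here: graph inputs that share one storage but occupy disjoint regions, so they must not be treated as aliases of each other:

```python
import torch

flat = torch.zeros(8)
p1 = flat.as_strided((4,), (1,), storage_offset=0)  # covers elements 0..3
p2 = flat.as_strided((4,), (1,), storage_offset=4)  # covers elements 4..7, no overlap
assert p1.untyped_storage().data_ptr() == p2.untyped_storage().data_ptr()
```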

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119826
Approved by: https://github.com/albanD
2024-02-14 01:46:30 +00:00
PyTorch MergeBot
4a5b2cd6cb Revert "Windows Dynamo Error Removal CI Check (#115969)"
This reverts commit 45e7af5818.

Reverted https://github.com/pytorch/pytorch/pull/115969 on behalf of https://github.com/PaliC due to this pr ended up breaking some of our periodic tests ([comment](https://github.com/pytorch/pytorch/pull/115969#issuecomment-1942934386))
2024-02-14 01:11:46 +00:00
Pearu Peterson
2c91e13afc Add lowerings to special functions (#119187)
As in the title.

In addition, the PR introduces infrastructure for lowerings of pointwise functions that have both cpp and triton implementations available.
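
A hedged usage sketch (the exact set of covered ops is per the PR): with pointwise lowerings available, compiled code can generate fused kernels for special functions instead of falling back to eager ATen calls.

```python
import torch

@torch.compile
def f(x):
    # both are unary pointwise special functions
    return torch.special.bessel_j0(x) * torch.special.i0e(x)

f(torch.randn(1024))
```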

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119187
Approved by: https://github.com/peterbell10
2024-02-11 16:35:40 +00:00
suo
0f12c0af44 [export] allow user input mutation in aot_export (#119356)
This PR enables input mutation in aot_export by removing the guard and ensuring that the GraphSignature is properly wired up.

This allows us to undo the gross hack in torch.export where we lift user inputs to buffers to work around the missing upstream support in aot_export. It also makes input mutation work properly in non-strict mode.

Mutations on inputs that require_grad are still banned (I added a test for a non-parameter input as well, just to make sure).
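
A hedged sketch of the newly allowed pattern, as seen from the public torch.export entry point (assuming it reflects this aot_export change):

```python
import torch

class M(torch.nn.Module):
    def forward(self, x):
        x.add_(1)   # mutation of a user input that does not require grad
        return x * 2

ep = torch.export.export(M(), (torch.zeros(3),))
print(ep.graph_signature)  # the signature records the mutated user input
```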

Differential Revision: [D53507440](https://our.internmc.facebook.com/intern/diff/D53507440/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119356
Approved by: https://github.com/bdhirsh, https://github.com/zhxchen17, https://github.com/titaiwangms
2024-02-08 22:02:24 +00:00
gs-olive
45e7af5818 Windows Dynamo Error Removal CI Check (#115969)
Rebase of #111313 onto `main`, for CI validation

Co-authored-by: Stella Laurenzo <stellaraccident@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115969
Approved by: https://github.com/ezyang
2024-02-08 21:23:45 +00:00
Isuru Fernando
81d12846dc Add decomp for pixel_shuffle/unshuffle (#118239)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118239
Approved by: https://github.com/peterbell10
2024-01-31 18:34:21 +00:00
Tugsbayasgalan Manlaibaatar
fa1e89b337 Ban mutation on dropout outputs in export (#117879)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117879
Approved by: https://github.com/ezyang
ghstack dependencies: #117811
2024-01-21 04:53:40 +00:00
Brian Hirsh
2db53a01e5 propagate torch stack trace metadata to copy_() nodes during input mutations (#117587)
Tested by running the below script:
```
import torch
@torch.compile(backend="aot_eager", fullgraph=True)
def f(x):
    y = x.view(-1)
    y.mul_(2)
    return

x = torch.ones(4)
f(x)
```

Which gives me this ATen graph (notice that the copy_() node is bundled under the stacktrace for `mul_(2)`):
```
 ===== Forward graph 0 =====
 <eval_with_key>.2 from /data/users/hirsheybar/e/pytorch/torch/fx/experimental/proxy_tensor.py:521 in wrapped class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[4]"):
        # File: /data/users/hirsheybar/e/pytorch/tmp5.py:8, code: y = x.view(-1)
        view: "f32[4]" = torch.ops.aten.view.default(arg0_1, [-1])

        # File: /data/users/hirsheybar/e/pytorch/tmp5.py:9, code: y.mul_(2)
        mul: "f32[4]" = torch.ops.aten.mul.Tensor(view, 2);  view = None
        view_1: "f32[4]" = torch.ops.aten.view.default(mul, [4]);  mul = None
        copy_: "f32[4]" = torch.ops.aten.copy_.default(arg0_1, view_1);  arg0_1 = view_1 = None
        return ()

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117587
Approved by: https://github.com/eellison
2024-01-17 23:07:45 +00:00
Joel Schlosser
7956ca16e6 Enable reverse view_funcs by default for python subclasses (#116512)
Part 3 of implementation for general [subclass view fake-ification](https://docs.google.com/document/d/1C5taWiplmX7nKiURXDOAZG2W5VNJ2iV0fQFq92H0Cxw).

Changes codegen to generate `view_func()` / `rev_view_func()` by default for python subclasses. With `view_func()` existing more often now, the lazy view rebase logic [here](f10c3f4184/torch/csrc/autograd/variable.cpp (L665-L695)) causes some slight behavior changes for in-place ops on views:
* Additional view nodes are inserted into output graphs, changing their string representation, although they are functionally the same. The extra nodes are removed in AOTAutograd's DCE pass.
* When `t` is a `FunctionalTensor`, calling `t.grad_fn` will now invoke `view_func()`; we need to make sure we're operating in a `FunctionalTensorMode` so the view op calls succeed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116512
Approved by: https://github.com/bdhirsh, https://github.com/soulitzer
ghstack dependencies: #115894
2024-01-05 16:48:12 +00:00
Joel Schlosser
3c21264c9b Introduce reverse view_funcs (#115894)
Part 2 of implementation for general [subclass view fake-ification](https://docs.google.com/document/d/1C5taWiplmX7nKiURXDOAZG2W5VNJ2iV0fQFq92H0Cxw).

Details:
* Codegen `rev_view_func()` alongside `view_func()`
    * Reverse view_func gives you a "base" from a "view": `rev_view_func(new_view) -> new_base` AKA it plays the original view backwards
* Utilizes the functional inverses defined in `FunctionalInverses.cpp`, passing `InverseReturnMode::AlwaysView`
* Manually implements functional inverses for `narrow()` and `chunk()`
* **NB: Multi-output views now set view_func() / rev_view_func() for each of the output views!**
    * Due to this, the `as_view()` overload that operates on a list of views is scrapped in favor of iteration via codegen

Example codegen in `ADInplaceOrViewTypeN.cpp`:
```cpp
at::Tensor narrow(c10::DispatchKeySet ks, const at::Tensor & self, int64_t dim, c10::SymInt start, c10::SymInt length) {
  auto _tmp = ([&]() {
    at::AutoDispatchBelowADInplaceOrView guard;
    return at::_ops::narrow::redispatch(ks & c10::after_ADInplaceOrView_keyset, self, dim, start, length);
  })();
  std::function<at::Tensor(const at::Tensor&)> func=nullptr;
  std::function<at::Tensor(const at::Tensor&)> rev_func=nullptr;
  if (false || !self.unsafeGetTensorImpl()->support_as_strided() ||
      c10::AutogradState::get_tls_state().get_view_replay_enabled()) {
    func = [=](const at::Tensor& input_base) {
      return at::_ops::narrow::call(input_base, dim, start, length);
    };
    rev_func = [=](const at::Tensor& input_view) {
      // NB: args from narrow() signature are passed along to the inverse
      return at::functionalization::FunctionalInverses::narrow_copy_inverse(self, input_view, at::functionalization::InverseReturnMode::AlwaysView, dim, start, length);
    };
  }
  auto result = as_view(/* base */ self, /* output */ _tmp, /* is_bw_differentiable */ true, /* is_fw_differentiable */ true, /* view_func */ func, /* rev_view_func */ rev_func, /* creation_meta */ InferenceMode::is_enabled() ? CreationMeta::INFERENCE_MODE : (at::GradMode::is_enabled() ? CreationMeta::DEFAULT : CreationMeta::NO_GRAD_MODE));
  return result;
}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115894
Approved by: https://github.com/soulitzer
2024-01-05 16:48:12 +00:00
Tugsbayasgalan Manlaibaatar
dfc898ede4 Don't decompose functional ops in predispatch functionalization (#116383)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116383
Approved by: https://github.com/bdhirsh
ghstack dependencies: #115188, #115210
2023-12-28 11:54:04 +00:00
Tugsbayasgalan Manlaibaatar
76b1d44d57 pre_dispatch aot_export (#115188)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115188
Approved by: https://github.com/bdhirsh
2023-12-25 04:51:21 +00:00
PyTorch MergeBot
0567f71ac6 Revert " pre_dispatch aot_export (#115188)"
This reverts commit a267d67350.

Reverted https://github.com/pytorch/pytorch/pull/115188 on behalf of https://github.com/jeanschmidt due to sadly, it is required to revert this commit in order to revert https://github.com/pytorch/pytorch/pull/115454 ([comment](https://github.com/pytorch/pytorch/pull/115188#issuecomment-1866310014))
2023-12-21 14:03:18 +00:00
Tugsbayasgalan Manlaibaatar
a267d67350 pre_dispatch aot_export (#115188)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115188
Approved by: https://github.com/bdhirsh
2023-12-20 21:36:25 +00:00
Tugsbayasgalan Manlaibaatar
d85314c95c Support Predispatch functionalization (#113728)
In this PR, we implement functionalization on the pre-dispatch graph. Today, every dispatch key except DispatchKey.Python has a dedicated mode stack in Python. Pre-dispatch tracing relies on this behaviour by pushing ProxyTorchDispatchMode onto the DispatchKey.PreDispatch mode stack and handling the dispatching logic in Python. To make pre-dispatch functionalization work, we now need to push FunctionalTensorMode onto the DispatchKey.PreDispatch mode stack and make sure it runs before ProxyTorchDispatchMode (this is very similar to how post-dispatch tracing works). Here are some design decisions we made for this flow to work:

1. FunctionalTensorMode internally calls C++ functionalize key. Since C++ functionalization goes after PreDispatch, if we are not careful, we will keep re-entering into PreDispatch key. We solve this by directly dispatching to C++ Functionalize key.

2. We delete the mode_stack_per_key logic because the only realistic time it is exercised is for PreDispatch, and it is in general not safe to use a plain list because the ordering of FunctionalTensorMode and ProxyTorchDispatchMode matters and is hard to enforce on a plain list. Instead, we now have a private class that tracks the PreDispatch mode stack.

3.  We will still run CompositeImplicitAutograd decomps in this PR, and disable this logic later as a followup.

Some missing bits after this PR:
1. Preserving autograd ops in a functional form. Right now they still show up in the graph but in a "non-functional" way.
2. Turn off CompositeImplicitAutograd decomps
3. Functionalizing HOO
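
A conceptual sketch of what functionalization does to a traced graph, using the public `torch.func.functionalize` wrapper as a stand-in for the FunctionalTensorMode that this PR pushes onto the PreDispatch mode stack:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.func import functionalize

def f(x):
    y = x.clone()
    y.add_(1)  # in-place op
    return y

gm = make_fx(functionalize(f))(torch.randn(3))
gm.print_readable()  # the in-place aten.add_ shows up as functional aten.add
```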

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113728
Approved by: https://github.com/bdhirsh
2023-12-19 20:28:35 +00:00
Mikayla Gawarecki
ac60a70e06 Migrated loss functions to ModuleInfos (#115584)
Migrates most tests in `common_nn.py:criterion_tests` to ModuleInfos.

**I can split this up if it is too large to review**

What this PR does not include:
- [`no_batch_dim` tests](https://github.com/pytorch/pytorch/blob/main/torch/testing/_internal/common_nn.py#L3995-L4112)
- [tests that use the functional variant of the loss function and `wrap_functional`](https://github.com/pytorch/pytorch/blob/main/torch/testing/_internal/common_nn.py#L1079-L1128)

#### On test times
This PR increases test time by ~58s locally
Before this PR:
```
>>> python test/test_nn.py -k Loss
Ran 1003 tests in 28.977s
```
After this PR
```
>>> python test/test_nn.py -k Loss
Ran 368 tests in 23.073s
```

```
>>> python test/test_modules.py -k Loss
Ran 836 tests in 63.900s
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115584
Approved by: https://github.com/janeyx99
ghstack dependencies: #115617
2023-12-14 16:21:05 +00:00
PyTorch MergeBot
626b7dc847 Revert "Migrated loss functions to ModuleInfos (#115584)"
This reverts commit f138b08d2e.

Reverted https://github.com/pytorch/pytorch/pull/115584 on behalf of https://github.com/atalman due to OSS CI oncall, breaks slow test ([comment](https://github.com/pytorch/pytorch/pull/115584#issuecomment-1854855080))
2023-12-13 23:34:30 +00:00
Mikayla Gawarecki
f138b08d2e Migrated loss functions to ModuleInfos (#115584)
Migrates most tests in `common_nn.py:criterion_tests` to ModuleInfos.

**I can split this up if it is too large to review**

What this PR does not include:
- [`no_batch_dim` tests](https://github.com/pytorch/pytorch/blob/main/torch/testing/_internal/common_nn.py#L3995-L4112)
- [tests that use the functional variant of the loss function and `wrap_functional`](https://github.com/pytorch/pytorch/blob/main/torch/testing/_internal/common_nn.py#L1079-L1128)

#### On test times
This PR increases test time by ~58s locally
Before this PR:
```
>>> python test/test_nn.py -k Loss
Ran 1003 tests in 28.977s
```
After this PR
```
>>> python test/test_nn.py -k Loss
Ran 368 tests in 23.073s
```

```
>>> python test/test_modules.py -k Loss
Ran 836 tests in 63.900s
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115584
Approved by: https://github.com/janeyx99
ghstack dependencies: #115617
2023-12-12 22:20:20 +00:00
Isuru Fernando
505574c46a Add decomposition for torch.block_diag (#115096)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115096
Approved by: https://github.com/peterbell10
2023-12-11 20:04:22 +00:00
Joel Schlosser
22704426c3 Expand dynamic dims support for traceable subclasses (#114311)
Continuation of #112185, following the design in this [doc](https://docs.google.com/document/d/1ipSxcTzEMMOAPvxP-YJlD5JBZZmIGgh8Q34ixtOUCRo).

Summary:
* Introduce `SubclassSymbolicPolicy` containing separate dynamic dim / constraint policies for the outer and inner tensors
    * Expand the automatic dynamic algorithm to recurse into inner tensors and produce one of these for a subclass instance
    * Maintain legacy behavior for subclasses by recursively calling `mark_dynamic()` on inner tensors *of the same dim as outer* when `mark_dynamic(outer, ...)` is called
    * Addresses this: 6a86cf00ad/torch/_dynamo/variables/builder.py (L1750)
* Add `outer_size` and `outer_stride` arguments to `__tensor_unflatten__()` so that you can find out what symbols were allocated for the outer size / stride (you are expected to return a tensor that compares equal to the outer symbols); a minimal subclass sketch follows this summary
    * Signatures now:
    ```python
    # attrs is a list of inner tensor attributes on x; inner_tensor = getattr(x, attr)
    # ctx is anything useful for rebuilding the class we want to guard on
    attrs, ctx = x.__tensor_flatten__()
    ...
    # inner_tensors is a dict of {attr -> tensor}
    # ctx is taken unmodified from flattening and (eventually) guarded on
    # outer_size is the expected size of the output; possibly symbolic
    # outer_stride is the expected strides of the output; possibly symbolic
    y = MySubclass.__tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride)

    # at the __tensor_unflatten__() call-site in PT2, we assert y.shape == outer_size and y.stride() == outer_stride
    # the assert simplifies symbols when there are relationships between outer and inner symbols
    ```
    * Size info needed for `NestedTensor` at least, stride info needed for `DTensor` at least
    * Punting on `outer_storage_offset` because storage_offset handling is horribly broken in PT2 right now
* ~~Add new `__tensor_mark_dynamic__()` to allow overriding the behavior of mark_dynamic on a per-subclass basis~~ (booted to future work)
* ~~Add guards for tensor subclasses by calling `__tensor_flatten__()` in the guard to test equality on `ctx`~~
    * Now handled in #114469
* Next PR: add TENSOR_MATCH guards on inner tensors
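
A minimal, hypothetical subclass illustrating the updated `__tensor_flatten__` / `__tensor_unflatten__` contract (only the signatures above are assumed; a real subclass would also implement `__torch_dispatch__`):

```python
import torch

class WrapperTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, inner):
        return torch.Tensor._make_wrapper_subclass(
            cls, inner.shape, strides=inner.stride(),
            dtype=inner.dtype, device=inner.device,
        )

    def __init__(self, inner):
        self.inner = inner

    def __tensor_flatten__(self):
        # attrs naming the inner-tensor attributes, plus a ctx to (eventually) guard on
        return ["inner"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
        # outer_size / outer_stride are the (possibly symbolic) expected outer
        # metadata; the returned tensor must compare equal to them
        return WrapperTensor(inner_tensors["inner"])
```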

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114311
Approved by: https://github.com/ezyang, https://github.com/drisspg, https://github.com/voznesenskym, https://github.com/bdhirsh
2023-12-05 21:09:25 +00:00
Tugsbayasgalan Manlaibaatar
7f49603ed3 Fix https://github.com/pytorch/pytorch/issues/114899 (#114985)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114985
Approved by: https://github.com/ydwu4
2023-12-03 05:24:02 +00:00
Brian Hirsh
c546ca9f80 AOTAutograd: support mutations on buffers that happen during the bw (#114953)
Re-land of https://github.com/pytorch/pytorch/pull/112906

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114953
Approved by: https://github.com/zou3519, https://github.com/drisspg
2023-12-01 23:09:37 +00:00
Brian Hirsh
64ccdd4afb AOTAutograd: keep input mutations in the graph if they are under no_grad, even if they require_grad (#114646)
Quick recap of events:

(1) https://github.com/pytorch/pytorch/pull/111347, which fixed a perf regression in 2.1 compared to 2.0, introduced a correctness problem around input mutations on inputs that require grad that show up in an inference-only graph (the specific case where this can happen is rare and nobody reported the issue, but it was fixed a few weeks later)

(2) That fix happened here: https://github.com/pytorch/pytorch/pull/113584, which makes sure to keep input mutations outside of the graph, so the autograd engine can set metadata properly on them

(3) That in turn caused a slight regression compared to (1), which is what this PR attempts to fix. In particular, code like the below is safe to keep the mutations in the graph for:

```
@torch.compile
def f(x):
    x.mul_(2)

x = torch.ones(2, requires_grad=True).clone()
# x requires_grad, so the input mutation will change some autograd metadata, like the version counter
# However, the mutation is under no_grad, so we don't have to worry about e.g. aliases of x having their .grad_fn fields changed
with torch.no_grad():
    f(x)
```

This particular case is pretty important to the shampoo optimizer code, which is run under `torch.compile`, and mutates parameters (which require grad).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114646
Approved by: https://github.com/zou3519
2023-11-29 04:29:32 +00:00
PyTorch MergeBot
48820c928c Revert "[test] AOTAutograd: support mutations on buffers that happen during the bw (#112906)"
This reverts commit c8974d649d.

Reverted https://github.com/pytorch/pytorch/pull/112906 on behalf of https://github.com/huydhn due to There are lots of failure after this change c8974d649d, this is probably a landrace ([comment](https://github.com/pytorch/pytorch/pull/112906#issuecomment-1831016362))
2023-11-29 00:49:57 +00:00
Brian Hirsh
c8974d649d [test] AOTAutograd: support mutations on buffers that happen during the bw (#112906)
I can hold off on reviews / landing until I talk to Driss and we confirm that we need this for FP8. This PR also needs testing and probably shouldn't land until Tugsuu's input mutation handling [PR](https://github.com/pytorch/pytorch/pull/111046) goes through.

What this PR tries to solve is the case where a model mutates some nn.Module state (a buffer) during the **backward** pass. It appears that this might be necessary for FP8's delayed scaling.

Today, AOTAutograd simply does not realize when graph inputs are mutated during the backward pass: it functionalizes the mutations away without recording that they were input mutations. This PR tries to:

(a) detect this situation (input mutations during the backward)

(b) put `copy_()`'s in the graph to properly handle the input mutation when we can. In cases where we can't keep the copy_() in the graph, we just error loudly (I imagine that these cases will be extremely rare, but we can fix them if they ever come up).

This is mostly a prototype for now, not ready for review.

I made this example locally to test out:
```
import torch

class MutatingAutogradFn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, buf):
        ctx.save_for_backward(buf)
        return x

    @staticmethod
    def backward(ctx, x_grad):
        buf = ctx.saved_tensors[0]
        buf.add_(x_grad)
        return x_grad * 3, None

class Mod(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.buf = torch.ones(2)

    @torch._dynamo.allow_in_graph
    def backward_mutating_fn(self, x, buf):
        return MutatingAutogradFn.apply(x, buf)

    def forward(self, x):
        tmp = self.backward_mutating_fn(x, self.buf)
        return tmp + self.buf

m = Mod()

x = torch.ones(2, requires_grad=True)
out = m(x)
# After the fw, buf should not have been mutated
print(m.buf)
out.sum().backward()
# bw has run, so buf should now be mutated
print(m.buf)
print(x.grad)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112906
Approved by: https://github.com/ezyang
2023-11-28 23:59:21 +00:00
voznesenskym
ddf1cb7870 AOTAutograd: handle set_(), detect metadata mutations that cancel out (#111554)
This should be enough to get @voznesenskym's FSDP branch to plumb `set_()` through AOTAutograd properly and have everything no-op out as expected. Main changes are:

(1) graph break on `aten::set_.source_Tensor_storage_offset` (we could support it but it isn't needed, seems safer to graph break)

(2) Functionalization: add a "proper" functionalization kernel for `aten::set_.source_Tensor`. The previous one we had was codegen'd and it was wrong (it would just clone() and call set_(), which does not do the right thing). I also manually mark on the `FunctionalTensorWrapper` when a given tensor has been mutated by a `set_()` call.

(3) AOTAutograd: I added a new field, `InputAliasInfo.mutates_storage_metadata`, so we can distinguish between "regular" metadata mutations, and metadata mutations due to `set_()` calls. This is mainly because at runtime, one requires calling `as_strided_()` to fix up metadata, while the other requires calling `set_()`.

(4) Made AOTAutograd's detection for metadata mutations / set_() mutations smarter and detect no-ops (if the storage and metadata are all the same).

I also killed `was_updated()` and `was_metadata_updated()`, and replaced them with (existing) `has_data_mutation()` and (new) `has_data_mutation()`, which can more accurately distinguish between data mutations vs. `set_()` calls vs. metadata mutations.

**This PR is still silently correct in one case though**, which I'd like to discuss more. In particular, this example:
```
def f(x):
    x_view = x.view(-1)
    x.set_(torch.ones(2))
    x_view.mul_(2)
    return
```

If you have an input that experiences both a data-mutation **and** a `x_old.set_(x_new)` call, there are two cases:

(a) the data mutation happened on the storage of `x_new`. This case should be handled automatically: if x_new is a graph intermediate then we will functionalize the mutation. If x_new is a different graph input, then we will perform the usual `copy_()` on that other graph input

(b) the data mutation happened on the storage of `x_old`. This is more of a pain to handle, and doesn't currently work. At runtime, the right thing to do is probably something like:
```

def functionalized_f(x):
    x_view = x.view(-1)
    # set_() desugars into a no-op; later usages of x will use x_output
    x_output = torch.ones(2)
    # functionalize the mutation on x_view
    x_view_updated = x.mul(2)
    x_updated = x_view_updated.view(x.shape)
    # x experienced TWO TYPES of mutations: a data mutation and a metadata mutation
    # We need to return both updated tensors in our graph
    return x_updated, x_output
def runtime_wrapper(x):
    x_data_mutation_result, x_set_mutation_result = compiled_graph(x)
    # First, perform the data mutation on x's old storage
    x.copy_(x_data_mutation_result)
    # Then, swap out the storage of x with the new storage
    x.set_(x_set_mutation_result)
```

There are two things that make this difficult to do though:

(1) Functionalization: the functionalization rule for `set_()` will fully throw away the old `FunctionalStorageImpl` on the graph input. So if there are any mutations to that `FunctionalStorageImpl` later on in the graph, the current graph input won't know about it. Maybe we can have a given `FunctionalTensorWrapper` remember all previous storages that it had, and track mutations on all of them - although this feels pretty complicated.

(2) AOTAutograd now needs to know that we might have *two* graph outputs that correspond to a single "mutated input", which is annoying.

It's worth pointing out that this issue is probably extremely unlikely for anyone to run into - can we just detect it and error? This feels slightly easier than solving it, although not significantly easier. We would still need `FunctionalTensorWrapper` to keep track of mutations on any of its "previous" storages, so it can report this info back to AOTAutograd so we can raise an error.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111554
Approved by: https://github.com/ezyang
ghstack dependencies: #113926
2023-11-28 19:33:35 +00:00
PyTorch MergeBot
3e1abde46d Revert "AOTAutograd: handle set_(), detect metadata mutations that cancel out (#111554)"
This reverts commit a911b4db9d.

Reverted https://github.com/pytorch/pytorch/pull/111554 on behalf of https://github.com/DanilBaibak due to The lower PR in the stack #113926 breaks the internal build ([comment](https://github.com/pytorch/pytorch/pull/111554#issuecomment-1822472206))
2023-11-22 10:13:48 +00:00
voznesenskym
a911b4db9d AOTAutograd: handle set_(), detect metadata mutations that cancel out (#111554)
This should be enough to get @voznesenskym's FSDP branch to plumb `set_()` through AOTAutograd properly and have everything no-op out as expected. Main changes are:

(1) graph break on `aten::set_.source_Tensor_storage_offset` (we could support it but it isn't needed, seems safer to graph break)

(2) Functionalization: add a "proper" functionalization kernel for `aten::set_.source_Tensor`. The previous one we had was codegen'd and it was wrong (it would just clone() and call set_(), which does not do the right thing). I also manually mark on the `FunctionalTensorWrapper` when a given tensor has been mutated by a `set_()` call.

(3) AOTAutograd: I added a new field, `InputAliasInfo.mutates_storage_metadata`, so we can distinguish between "regular" metadata mutations, and metadata mutations due to `set_()` calls. This is mainly because at runtime, one requires calling `as_strided_()` to fix up metadata, while the other requires calling `set_()`.

(4) Made AOTAutograd's detection for metadata mutations / set_() mutations smarter and detect no-ops (if the storage and metadata are all the same).

I also killed `was_updated()` and `was_metadata_updated()`, and replaced them with (existing) `has_data_mutation()` and (new) `has_data_mutation()`, which can more accurately distinguish between data mutations vs. `set_()` calls vs. metadata mutations.

**This PR is still silently correct in one case though**, which I'd like to discuss more. In particular, this example:
```
def f(x):
    x_view = x.view(-1)
    x.set_(torch.ones(2))
    x_view.mul_(2)
    return
```

If you have an input that experiences both a data-mutation **and** a `x_old.set_(x_new)` call, there are two cases:

(a) the data mutation happened on the storage of `x_new`. This case should be handled automatically: if x_new is a graph intermediate then we will functionalize the mutation. If x_new is a different graph input, then we will perform the usual `copy_()` on that other graph input

(b) the data mutation happened on the storage of `x_old`. This is more of a pain to handle, and doesn't currently work. At runtime, the right thing to do is probably something like:
```

def functionalized_f(x):
    x_view = x.view(-1)
    # set_() desugars into a no-op; later usages of x will use x_output
    x_output = torch.ones(2)
    # functionalize the mutation on x_view
    x_view_updated = x.mul(2)
    x_updated = x_view_updated.view(x.shape)
    # x experienced TWO TYPES of mutations: a data mutation and a metadata mutation
    # We need to return both updated tensors in our graph
    return x_updated, x_output
def runtime_wrapper(x):
    x_data_mutation_result, x_set_mutation_result = compiled_graph(x)
    # First, perform the data mutation on x's old storage
    x.copy_(x_data_mutation_result)
    # Then, swap out the storage of x with the new storage
    x.set_(x_set_mutation_result)
```

There are two things that make this difficult to do though:

(1) Functionalization: the functionalization rule for `set_()` will fully throw away the old `FunctionalStorageImpl` on the graph input. So if there are any mutations to that `FunctionalStorageImpl` later on in the graph, the current graph input won't know about it. Maybe we can have a given `FunctionalTensorWrapper` remember all previous storages that it had, and track mutations on all of them - although this feels pretty complicated.

(2) AOTAutograd now needs to know that we might have *two* graph outputs that correspond to a single "mutated input", which is annoying.

It's worth pointing out that this issue is probably extremely unlikely for anyone to run into - can we just detect it and error? This feels slightly easier than solving it, although not significantly easier. We would still need `FunctionalTensorWrapper` to keep track of mutations on any of its "previous" storages, so it can report this info back to AOTAutograd so we can raise an error.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111554
Approved by: https://github.com/ezyang
ghstack dependencies: #113926
2023-11-21 01:52:46 +00:00
Brian Hirsh
cc11c0d11b aot_autograd: keep input mutations on requires_grad=True tensor out of the graph for inference (#113584)
The original behavior of torch.compile w.r.t. input mutations is that if an input to a graph is mutated **and** requires grad, we keep the input mutation outside of the graph and replay it at runtime.

This is important because, e.g., an input can have outstanding aliases, and mutating the input in eager mode will cause autograd to change the `grad_fn` of all outstanding aliases.

It looks like landing https://github.com/pytorch/pytorch/pull/111347 changed this behavior slightly:
* The linked PR makes it possible for AOTAutograd to go down the inference code path, even if some inputs require grad (because all of the outputs of the graph were seen to not require grad)
* AOTAutograd's logic in the inference code path today is to **always** keep input mutations in the graph.

This PR fixes that regression: regardless of inference vs. training, we should always keep input mutations outside of the graph if the input requires_grad.
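
A hedged, hypothetical repro of the protected scenario: an input that requires grad and has an outstanding alias is mutated in a graph whose outputs do not require grad (i.e. the inference path):

```python
import torch

@torch.compile(backend="aot_eager")
def f(x):
    x.mul_(2)              # input mutation
    return x.detach() + 1  # no output requires grad -> inference code path

x = torch.ones(3, requires_grad=True).clone()  # non-leaf, so in-place is allowed
alias = x.view(-1)  # replaying the mutation outside the graph keeps alias's autograd metadata correct
f(x)
```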

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113584
Approved by: https://github.com/tugsbayasgalan
ghstack dependencies: #113267, #113416
2023-11-15 19:55:47 +00:00
PyTorch MergeBot
0e6b6a2483 Revert "AOTAutograd: handle set_(), detect metadata mutations that cancel out (#111554)"
This reverts commit 3afb4e5cf7.

Reverted https://github.com/pytorch/pytorch/pull/111554 on behalf of https://github.com/clee2000 due to the xla failure is real sorry, log classifier is showing the wrong line ([comment](https://github.com/pytorch/pytorch/pull/111554#issuecomment-1809177978))
2023-11-13 21:46:57 +00:00
Brian Hirsh
3afb4e5cf7 AOTAutograd: handle set_(), detect metadata mutations that cancel out (#111554)
This should be enough to get @voznesenskym's FSDP branch to plumb `set_()` through AOTAutograd properly and have everything no-op out as expected. Main changes are:

(1) graph break on `aten::set_.source_Tensor_storage_offset` (we could support it but it isn't needed, seems safer to graph break)

(2) Functionalization: add a "proper" functionalization kernel for `aten::set_.source_Tensor`. The previous one we had was codegen'd and it was wrong (it would just clone() and call set_(), which does not do the right thing). I also manually mark on the `FunctionalTensorWrapper` when a given tensor has been mutated by a `set_()` call.

(3) AOTAutograd: I added a new field, `InputAliasInfo.mutates_storage_metadata`, so we can distinguish between "regular" metadata mutations, and metadata mutations due to `set_()` calls. This is mainly because at runtime, one requires calling `as_strided_()` to fix up metadata, while the other requires calling `set_()`.

(4) Made AOTAutograd's detection for metadata mutations / set_() mutations smarter and detect no-ops (if the storage and metadata are all the same).

I also killed `was_updated()` and `was_metadata_updated()`, and replaced them with (existing) `has_data_mutation()` and (new) `has_data_mutation()`, which can more accurately distinguish between data mutations vs. `set_()` calls vs. metadata mutations.

**This PR is still silently correct in one case though**, which I'd like to discuss more. In particular, this example:
```
def f(x):
    x_view = x.view(-1)
    x.set_(torch.ones(2))
    x_view.mul_(2)
    return
```

If you have an input that experiences both a data-mutation **and** a `x_old.set_(x_new)` call, there are two cases:

(a) the data mutation happened on the storage of `x_new`. This case should be handled automatically: if x_new is a graph intermediate then we will functionalize the mutation. If x_new is a different graph input, then we will perform the usual `copy_()` on that other graph input

(b) the data mutation happened on the storage of `x_old`. This is more of a pain to handle, and doesn't currently work. At runtime, the right thing to do is probably something like:
```

def functionalized_f(x):
    x_view = x.view(-1)
    # set_() desugars into a no-op; later usages of x will use x_output
    x_output = torch.ones(2)
    # functionalize the mutation on x_view
    x_view_updated = x.mul(2)
    x_updated = x_view_updated.view(x.shape)
    # x experienced TWO TYPES of mutations: a data mutation and a metadata mutation
    # We need to return both updated tensors in our graph
    return x_updated, x_output
def runtime_wrapper(x):
    x_data_mutation_result, x_set_mutation_result = compiled_graph(x)
    # First, perform the data mutation on x's old storage
    x.copy_(x_data_mutation_result)
    # Then, swap out the storage of x with the new storage
    x.set_(x_set_mutation_result)
```

There are two things that make this difficult to do though:

(1) Functionalization: the functionalization rule for `set_()` will fully throw away the old `FunctionalStorageImpl` on the graph input. So if there are any mutations to that `FunctionalStorageImpl` later on in the graph, the current graph input won't know about it. Maybe we can have a given `FunctionalTensorWrapper` remember all previous storages that it had, and track mutations on all of them - although this feels pretty complicated.

(2) AOTAutograd now needs to know that we might have *two* graph outputs that correspond to a single "mutated input", which is annoying.

It's worth pointing out that this issue is probably extremely unlikely for anyone to run into - can we just detect it and error? This feels slightly easier than solving it, although not significantly easier. We would still need `FunctionalTensorWrapper` to keep track of mutations on any of its "previous" storages, so it can report this info back to AOTAutograd so we can raise an error.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111554
Approved by: https://github.com/ezyang
2023-11-13 16:39:25 +00:00
Kaichao You
958f755a0e [FX][CodeGen] Make sure fx code is valid in python (#113345)
This PR fixes two cases where FX-generated code is invalid Python (syntax errors):

1. multiple type annotations in one line: `var1: annotation1, var2: annotation2 = function_call()`
2. invalid type annotation for scalars like `var1: f32[] = function_call()`.
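
A hand-written illustration (not actual FX output) of why each form fails to parse:

```python
# Both strings mimic the previously generated shapes; compiling them shows the
# SyntaxError that the fixed codegen now avoids.
bad_multi_annotation = 'var1: "f32[2]", var2: "f32[2]" = function_call()'
bad_scalar_annotation = "var1: f32[] = function_call()"

for src in (bad_multi_annotation, bad_scalar_annotation):
    try:
        compile(src, "<fx-generated>", "exec")
    except SyntaxError as exc:
        print("SyntaxError:", exc.msg)
```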

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113345
Approved by: https://github.com/ezyang
2023-11-10 21:12:16 +00:00
eellison
325e0fdfdd Enable masked_scatter_backward for inductor (#109642)
masked_scatter_backward was previously implemented as a
CompositeExplicitAutograd, which involved a decomp that calls
masked_select, and masked_select in general produces data-dependent
shapes that inductor doesn't support. But masked_scatter_backward
reshapes the return value of masked_select such that the end result has
a static shape again.

I have converted masked_scatter_backward into an aten op to avoid this
issue.
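
A hedged illustration of the shape issue: the backward of `masked_scatter` selects gradient values at a data-dependent number of positions, but the final gradient has the source's static shape:

```python
import torch

src = torch.randn(4, requires_grad=True)
base = torch.zeros(4)
mask = torch.tensor([True, False, True, False])

out = base.masked_scatter(mask, src)  # consumes src elements in order at the True slots
out.sum().backward()
print(src.grad)  # 1 for the consumed source elements, 0 for the rest; shape stays (4,)
```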

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109642
Approved by: https://github.com/ezyang
ghstack dependencies: #108170
2023-11-09 01:27:57 +00:00
Tugsbayasgalan Manlaibaatar
84d64d72d6 Persist copy_ in training graph for inputs that don't require grad (#111046)
In this PR, we try to keep an input mutation in the forward graph iff it is a data mutation (not a metadata mutation) and the input doesn't require grad. This is for optimizing inductor training graphs. (For more details: https://github.com/pytorch/pytorch/issues/109240)

We keep the input mutation in the graph by wrapping the original callable in a wrapper function that adds an input.copy_(updated_input) call at the end, which is then traced via make_fx. Previously, this was only enabled for the forward-only path and unconditionally disabled for the joint graph.

Another caveat is that when we are tracing through tensor subclasses, we won't allow any input mutations to be preserved in the graph. The reason is that it makes the code logic quite ugly for no obvious performance improvement.

Most of the changes in this PR are mechanical and I didn't have to make any changes to the partitioner. Previously, forward/backward heavily relied on the metadata field `num_mutated_inps` to figure out whether something is returned as an extra output or not. Now that we keep some mutations in the graph, we need to propagate something like `num_mutated_inps - num_graph_handled_inps`.
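
A hedged sketch of the target pattern (names are illustrative): a training graph where a graph input that does not require grad, e.g. optimizer state, is data-mutated, so its `copy_()` can now stay in the graph:

```python
import torch

@torch.compile(backend="aot_eager")
def step(param, grad, exp_avg):
    exp_avg.mul_(0.9).add_(grad, alpha=0.1)  # data mutation of a no-grad input
    return param - 0.01 * exp_avg            # output requires grad -> training graph

param = torch.randn(4, requires_grad=True)
step(param, torch.randn(4), torch.zeros(4))
```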

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111046
Approved by: https://github.com/ezyang, https://github.com/bdhirsh
2023-11-09 00:40:29 +00:00
Edward Z. Yang
1f3fa13f0a Handle unbacked SymInt sized outputs in AOTAutograd (#113159)
Thanks aakhundov for constructing the test case. This PR was constructed by running the failing test case, and then fixing problems until we got all the way to the end. There are a few distinct fixes:

* AOTAutograd performs equality tests on tensor metadata to determine if a metadata mutation had occurred. If we test i0 vs i1, we should report these are NOT equal, since obviously we have somehow resized the tensor from i0 to i1 (even if, on a particular run, it is possible i0 == i1).
* There's a sketchy fix for `test_aot_autograd_exhaustive_matmul_cpu_float32` where we check if the output shape equals the tangent shape. Unfortunately, the same `definitely_true` treatment does not work here; it still fails on the example. I piled an extra sketchy fix on top of it, where I just try my best to avoid doing the view. Maybe we should have some sort of logging here.
* Partitioner needs to get out a size for unbacked SymInt when partitioning. I just feed it a random heuristic value in this case, similar to how we've been dealing with this in Inductor.
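
A hedged sketch of where unbacked SymInt sizes come from (assumes the `capture_dynamic_output_shape_ops` config flag; exact flags and backends may differ):

```python
import torch

torch._dynamo.config.capture_dynamic_output_shape_ops = True

@torch.compile(backend="aot_eager")
def f(x):
    nz = torch.nonzero(x)  # output length is data-dependent, i.e. an unbacked SymInt
    return nz.float() * 2

f(torch.tensor([0.0, 1.0, 0.0, 2.0]))
```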

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113159
Approved by: https://github.com/aakhundov, https://github.com/bdhirsh
2023-11-08 04:28:38 +00:00
PaliC
542fa4a2e7 Revert "Revert "Use OpOverload instead of OpOverloadPacket for size/s… (#113058)
Revert "Revert "Use OpOverload instead of OpOverloadPacket for size/stride/etc slots (#112119)""

This reverts commit a1d1b73a7c.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113058
Approved by: https://github.com/izaitsevfb
2023-11-06 19:38:49 +00:00
PyTorch MergeBot
a1d1b73a7c Revert "Use OpOverload instead of OpOverloadPacket for size/stride/etc slots (#112119)"
This reverts commit 2337d8d062.

Reverted https://github.com/pytorch/pytorch/pull/112119 on behalf of https://github.com/PaliC due to still breaking trt tests :( refer to diff ([comment](https://github.com/pytorch/pytorch/pull/112119#issuecomment-1795496395))
2023-11-06 17:01:50 +00:00
Edward Z. Yang
2337d8d062 Use OpOverload instead of OpOverloadPacket for size/stride/etc slots (#112119)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112119
Approved by: https://github.com/yanboliang
2023-11-03 13:54:41 +00:00
Kazuaki Ishizaki
9089242048 Fix typo under test directory (#112346)
This PR fixes typo in comments and messages under `test` directory. This PR also fixes related typo in messages under `torch` directory.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112346
Approved by: https://github.com/kit1980, https://github.com/ezyang
2023-11-03 07:53:33 +00:00
PyTorch MergeBot
25e17f3522 Revert "Use OpOverload instead of OpOverloadPacket for size/stride/etc slots (#112119)"
This reverts commit dd24e92949.

Reverted https://github.com/pytorch/pytorch/pull/112119 on behalf of https://github.com/ZainRizvi due to Breaking internal tests. See D50912326 ([comment](https://github.com/pytorch/pytorch/pull/112119#issuecomment-1791072363))
2023-11-02 16:32:25 +00:00
Edward Z. Yang
a1ab22b81d Reland "Trigger specialization when you call size()/stride() from C++ (#111935)" (#112605)
This reverts commit 22221c6d60.

Differential Revision: [D50886564](https://our.internmc.facebook.com/intern/diff/D50886564)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112605
Approved by: https://github.com/voznesenskym
2023-11-02 13:27:31 +00:00
Edward Z. Yang
dd24e92949 Use OpOverload instead of OpOverloadPacket for size/stride/etc slots (#112119)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112119
Approved by: https://github.com/yanboliang
2023-11-01 18:26:01 +00:00