Commit Graph

247 Commits

Author SHA1 Message Date
Yanbo Liang
29523aa113 [Dynamo][autograd.Function] Relax backward speculation strict mode a bit (#146571)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146571
Approved by: https://github.com/zou3519
2025-02-11 05:39:00 +00:00
Animesh Jain
487400f47f [dynamo] Support functools.partial variables through inspect.signature (#146339)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146339
Approved by: https://github.com/jansel
ghstack dependencies: #146322, #146116
2025-02-04 04:39:39 +00:00
Harmen Stoppels
01554c7b5a fix incorrect literal strings / accidental tuples (#146037)
* `expr,` is short for `(expr,)`
* literal strings over multiple lines need to escape the newline `\` or use `(...)`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146037
Approved by: https://github.com/Skylion007
2025-02-03 15:08:11 +00:00
Animesh Jain
31fb691782 [dynamo] Graph break on tensor.retain_grad (#146214)
Fixes https://github.com/pytorch/pytorch/issues/146212

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146214
Approved by: https://github.com/jansel
ghstack dependencies: #146062, #146198, #146258
2025-02-02 03:12:36 +00:00
Aaron Orenstein
f3120f6d26 Remove incorrect BuiltinVariable.call_hasattr() (#145551)
BuiltinVariable.call_hasattr() overrides the base class - but actually behaves differently. The base is `obj.call_hasattr(tx, attr)` but BuiltinVariable's version is `<unused>.call_hasattr(tx, obj, attr)`.

The BuiltinVariable version is used as a pattern from `call_self_handler()` for `BuiltinVariable(hasattr)`. I think the other version is just used for internal `hasattr(obj, name)` so I renamed that one to `call_obj_hasattr`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145551
Approved by: https://github.com/anijain2305
2025-01-30 22:21:19 +00:00
Aaron Orenstein
a79100ab11 PEP585 update - torch/_dynamo (#145105)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145105
Approved by: https://github.com/bobrenjc93
2025-01-18 20:47:11 +00:00
Tom Ritchford
dc23f1944a Remove unused Python variables in torch/[_-a]* (#133492)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133492
Approved by: https://github.com/albanD
2024-12-12 17:39:14 +00:00
PyTorch MergeBot
5c97ac9721 Revert "Remove unused Python variables in torch/[_-a]* (#133492)"
This reverts commit fda975a7b3.

Reverted https://github.com/pytorch/pytorch/pull/133492 on behalf of https://github.com/clee2000 due to Sorry, I need to revert this in order to revert something else.  The only thing you need to do is rebase and remerge ([comment](https://github.com/pytorch/pytorch/pull/133492#issuecomment-2536635516))
2024-12-11 17:29:12 +00:00
Tom Ritchford
fda975a7b3 Remove unused Python variables in torch/[_-a]* (#133492)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133492
Approved by: https://github.com/albanD
2024-12-10 21:48:44 +00:00
Ryan Guo
2da98d9757 [dynamo] Support is comparison for symnodes (#140754)
Fixes #109504.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140754
Approved by: https://github.com/williamwen42
2024-11-19 00:19:33 +00:00
Michael Lazos
8b2e3855a9 Make size a property with an assertion (#139794)
Fixes https://github.com/pytorch/pytorch/issues/120568

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139794
Approved by: https://github.com/williamwen42
2024-11-09 03:39:41 +00:00
Michael Lazos
ea0f60ecfa [Dynamo] allow dynamic callables on tensor variables (#137940)
Fixes https://github.com/pytorch/pytorch/issues/134844

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137940
Approved by: https://github.com/williamwen42
2024-11-08 23:49:34 +00:00
Yidi Wu
6734cb7bf2 [hop free symbols] refactor tensor.to_list implementation to call wrap_fx_proxy. (#139663)
Refactoring only. Previously, we manually cal SymNodeVariable.create, now we handle it with wrap_fx_proxy. This unifies the handling of operations that produce symints in wrap_fx_proxy.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139663
Approved by: https://github.com/zou3519
ghstack dependencies: #138345, #138428, #138558, #138737, #138559
2024-11-05 20:19:09 +00:00
Ryan Guo
693a0a1bd4 [dynamo][NFC] Rename mutable_local and add documentation (#139339)
This patch addresses the renaming part of #133027, specifically, it
renames the following and adds documentation for relevant classes.
1. `VariableTracker.mutable_local` to `mutation_type`
2. `MatableLocal `to `ValueMutationNew`
3. `MutableSideEffects `to `ValueMutationExisting`
4. `MutableLocalSource` to `SourceType`
5. `MutableLocalSource.Local` to `New`

Note that (2), (3) and (5) are mainly to bring consistency between them
and `AttributeMutationNew`, `AttributeMutationExisting`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139339
Approved by: https://github.com/jansel, https://github.com/mlazos, https://github.com/anijain2305
2024-11-05 19:11:41 +00:00
Tomasz Bohutyn
9dc5851f5d handle more devices in method_type method of TensorVariable (#138078)
Fixes #138077

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138078
Approved by: https://github.com/jgong5, https://github.com/ezyang
2024-11-05 18:19:52 +00:00
Ryan Guo
30a83ca991 [dynamo] Improve codegen for DataPtrVariable and fix tensor reference issue (#139487)
This addresses
https://github.com/pytorch/pytorch/pull/137677/files#r1799836499, which
had to set `allow_cache=False` for codegen on `DataPtrVariable.base`,
which is a `TensorVariable`, otherwise we observe failure of
`test_no_grad_copy` when testing with Dynamo.

I've seen `test_no_grad_copy` failing a few times, and every single time
it's related to cyclic reference, my best guess is the cyclic reference
holds some tensor object longer in memory than necessary, preventing the
optimization introduced in #11165.

This patch makes `OutputGraph.cleanup()` more aggressive by clearing out
all fields that might reference a `VariableTracker`. As a result, we can
remove the aforementioned `allow_cache=False`, which helps generate
better code (e.g., in the case of `test_no_grad_copy`, it skipped generating
a redundant graph whose only op is returning the input tensor; instead we just
generate a single `LOAD_FAST`).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139487
Approved by: https://github.com/jansel, https://github.com/aakhundov
2024-11-04 19:14:06 +00:00
PyTorch MergeBot
b6b9596607 Revert "[dynamo] Fix constant propagation in builtins and UserClasses (#131354)"
This reverts commit 44257c063e.

Reverted https://github.com/pytorch/pytorch/pull/131354 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but it seems to break some internal tests ([comment](https://github.com/pytorch/pytorch/pull/131354#issuecomment-2451050605))
2024-11-01 00:13:20 +00:00
Tom Ritchford
44257c063e [dynamo] Fix constant propagation in builtins and UserClasses (#131354)
* Fixes https://github.com/pytorch/pytorch/issues/118675
* Replaces https://github.com/pytorch/pytorch/pull/118994

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131354
Approved by: https://github.com/jansel, https://github.com/anijain2305
2024-10-30 12:47:20 +00:00
Simon Fan
fd9f4e6770 Back out "[compiled autograd] tls access helpers (#138061)" and Back out "[compiled autograd] Compiled autograd configs in TLS (#137821)" (#139086)
Summary:
Original commit changeset: 9bf80c1492d7

Original Phabricator Diff: D64796226

Original commit changeset: aa1d9ef8f6e6

Original Phabricator Diff: D64796212

Differential Revision: D65072644

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139086
Approved by: https://github.com/malfet
2024-10-28 23:37:05 +00:00
Joel Schlosser
14a17ad630 Elide calls to is_nested in Dynamo-traced graphs (#138841)
Before this PR, calling `is_nested` in-graph would result in graph code like the following:
```python
  class GraphModule(torch.nn.Module):
      def forward(self, L_nt_: "f64[3, s1, 5]", s1: "Sym(s1)"):
          l_nt_ = L_nt_

          # Note this useless line!
          getattr_1 = l_nt_.is_nested;  getattr_1 = None

          add: "f64[3, s1, 5]" = l_nt_ + 2;  l_nt_ = None
          return (add,)
```

This PR follows what is done for `is_sparse` / `is_quantized`: store it onto `TensorVariable` and have `getattr` calls to `is_nested` return the stored value as a constant. This removes the useless line above from the graph. Note that guarding is handled through tensor type check guards, so no need to guard on `is_nested` status.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138841
Approved by: https://github.com/soulitzer
2024-10-26 15:03:32 +00:00
Bob Ren
500b2bc781 Have as_tensor always return a float64 tensor in dynamo (#138598)
As discussed with @ezyang, this set of diffs are extracting fixes to problems discovered to flipping `specialize_float=False` in https://github.com/pytorch/pytorch/pull/137782. Since these codepaths are exercised in existing tests, I'm going to bias towards shipping speed and put these up with the primary test plan as the global CI. These code paths are all tested via existing tests when `specialize_float=False` and it feels a bit wonky to add more gated tests that only test behavior when this flag is True, especially since these code paths are already covered. That being said, I'm happy to add individual tests if reviewers insist or have a different POV.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138598
Approved by: https://github.com/ezyang
ghstack dependencies: #138595
2024-10-24 20:50:28 +00:00
Simon Fan
5a13282c75 [compiled autograd] tls access helpers (#138061)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138061
Approved by: https://github.com/yf225
ghstack dependencies: #137953, #137821
2024-10-22 08:03:52 +00:00
Simon Fan
49fa437097 [compiled autograd] Compiled autograd configs in TLS (#137821)
Multithreaded doesn't work yet, this adds python side TLS only for the python side state

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137821
Approved by: https://github.com/jansel, https://github.com/yf225
ghstack dependencies: #137953
2024-10-22 08:03:52 +00:00
Tom Ritchford
e1c4548441 [dynamo] Simplify creation of VariableTrackers (#135714)
## `VariableTracker::build()` hides the Builders

### The problem

In the current code, creating a `VariableTracker` involves choosing one of two `Builder` classes and either calling a method, or calling a constructor that creates an object that you immediately call, [like this](083c9149b7/torch/_dynamo/variables/functions.py (L761-L768)).

Variations on this code are repeated in many places.

More, the `Builder` classes have a lot of dependencies, so they have to be loaded late in the whole import process to avoid circular imports, so they end up being repeatedly imported at local scope.

### The solution

In this commit, the import from `builder` and the logic of choosing and calling the Builder class are hidden in a single static factory method, `VariableTracker.build()`, easier to reason about and to import.

This commit net lowers the total lines of code by over 150 lines by removing repetitive logic and unnecessary local imports.

**CHANGES:** Originally the name of the static method was `VariableTracker.create()` but a static method on a derived class, `LazyVariableTracker.create()` now exists with a different signature that's irreconcilable, so the new static method was renamed to `VariableTracker.build()`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135714
Approved by: https://github.com/jansel
2024-10-18 09:36:46 +00:00
PyTorch MergeBot
361f42bc42 Revert "[compiled autograd] Compiled autograd configs in TLS (#137821)"
This reverts commit 9aba0b91c8.

Reverted https://github.com/pytorch/pytorch/pull/137821 on behalf of https://github.com/wdvr due to Reverting this for now, it is failing test_public_bindings in trunk ([comment](https://github.com/pytorch/pytorch/pull/137821#issuecomment-2417351788))
2024-10-16 16:38:29 +00:00
Simon Fan
9aba0b91c8 [compiled autograd] Compiled autograd configs in TLS (#137821)
Multithreaded doesn't work yet, this adds python side TLS only for the python side state

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137821
Approved by: https://github.com/jansel, https://github.com/yf225
ghstack dependencies: #137953
2024-10-16 09:28:32 +00:00
Adnan Akhundov
809ff3b274 Add host-side Triton TMA support to Dynamo (#137677)
This adds Dynamo tracing support for the host-side Triton TMA API (see `create_2d_tma_descriptor` calls on the host in the [Triton tutorial](https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html#sphx-glr-getting-started-tutorials-09-persistent-matmul-py)). A few notes:

- Here we assume the availability of the host-side TMA API added to upstream Triton in https://github.com/triton-lang/triton/pull/4498. As of time of writing, this is not a part of the PT2 OSS Triton pin (although back-ported internally). OSS Triton pin update should be done in December 2024.
- To capture the chain of calls `t.data_ptr() --> create_{1d,2d}_tma_descriptor(ptr, ...) --> kernel[grid](tma_desc, ...)`, we add three new variable trackers: `DataPtrVariable`, `CreateTMADescriptorVariable` (for the function), `TMADescriptorVariable` (for TMA descriptor object). This is to maintain the path back from the Triton kernel to the Tensor from which the TMA descriptor has been created.
- The newly introduced variables have `reconstruct` methods used in case of graph breaks.
- The `tma_descriptor_metadata` extracted from the captured `create_{1d,2d}_tma_descriptor` calls is propagated through the HOPs in Dynamo and AOTAutograd to be used by the downstream compiler (e.g., Inductor). See the unit tests for how the captured HOP arguments look like.
- In the Dynamo-captured fx graph, we replace the TMA descriptor arguments of the Triton kernel by the underlying Tensors, to be able to track the input/output relationships in terms of Tensors.
- In the Triton kernel mutation analysis pass (in AOTAutograd), we use the `tt.experimental_descriptor_store` TTIR op to detect mutations of the underlying tensors via TMA descriptors. So that downstream AOTAutograd can perform functionalizations as required.
- JIT Inductor and AOT Inductor support will be implemented in follow-up PRs.

Differential Revision: [D64404928](https://our.internmc.facebook.com/intern/diff/D64404928)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137677
Approved by: https://github.com/zou3519
2024-10-16 02:18:48 +00:00
Richeek Das
080f02ac7a [dynamo] do not raise an unimplemented error with boolean masking setitem (#134902)
Cudagraph breaks on boolean masking setitem, however the code runs fine. There is no need to raise an unimplemented error here, since it already warns that its an incompatible op.

Fixes #134241

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134902
Approved by: https://github.com/jansel, https://github.com/henrylhtsang
2024-10-10 19:11:40 +00:00
Michael Lazos
d5785d4295 [Dynamo] Handle torch function subclass/mode dispatch on generic tensor methods (#137119)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137119
Approved by: https://github.com/williamwen42, https://github.com/anijain2305
ghstack dependencies: #137114, #137115, #137116, #137117, #137120, #137227
2024-10-09 02:29:40 +00:00
Michael Lazos
0a304d9048 [Dynamo] Handle extracted unbound tensor methods (#137227)
fixes2

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137227
Approved by: https://github.com/williamwen42, https://github.com/anijain2305
ghstack dependencies: #137114, #137115, #137116, #137117, #137120
2024-10-09 02:29:40 +00:00
Michael Lazos
27dee935af [Dynamo] Ensure torch function modes are dispatched on builtin ops (#137117)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137117
Approved by: https://github.com/yanboliang, https://github.com/williamwen42
ghstack dependencies: #137114, #137115, #137116
2024-10-09 02:29:40 +00:00
PyTorch MergeBot
2d18c2d5e7 Revert "[Dynamo] Ensure torch function modes are dispatched on builtin ops (#137117)"
This reverts commit 941be418d8.

Reverted https://github.com/pytorch/pytorch/pull/137117 on behalf of https://github.com/huydhn due to The top of the stack has been reverted but it leaves trunk in a broken state, so I try to revert the rest of the stack ([comment](https://github.com/pytorch/pytorch/pull/137114#issuecomment-2400765603))
2024-10-08 20:33:17 +00:00
PyTorch MergeBot
76c5bdd2cc Revert "[Dynamo] Handle extracted unbound tensor methods (#137227)"
This reverts commit 14eabd6915.

Reverted https://github.com/pytorch/pytorch/pull/137227 on behalf of https://github.com/malfet due to Need to revert to be able to revert https://github.com/pytorch/pytorch/pull/136910 ([comment](https://github.com/pytorch/pytorch/pull/137227#issuecomment-2400406384))
2024-10-08 17:12:41 +00:00
PyTorch MergeBot
c88c0e6c65 Revert "[Dynamo] Handle torch function subclass/mode dispatch on generic tensor methods (#137119)"
This reverts commit d255b34c0a.

Reverted https://github.com/pytorch/pytorch/pull/137119 on behalf of https://github.com/malfet due to Need to revert to be able to revert https://github.com/pytorch/pytorch/pull/136910 ([comment](https://github.com/pytorch/pytorch/pull/137119#issuecomment-2400401262))
2024-10-08 17:09:26 +00:00
Michael Lazos
d255b34c0a [Dynamo] Handle torch function subclass/mode dispatch on generic tensor methods (#137119)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137119
Approved by: https://github.com/williamwen42
ghstack dependencies: #137114, #137115, #137116, #137117, #137120, #137227
2024-10-07 18:55:26 +00:00
Michael Lazos
14eabd6915 [Dynamo] Handle extracted unbound tensor methods (#137227)
fixes2

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137227
Approved by: https://github.com/williamwen42
ghstack dependencies: #137114, #137115, #137116, #137117, #137120
2024-10-07 18:55:26 +00:00
Michael Lazos
941be418d8 [Dynamo] Ensure torch function modes are dispatched on builtin ops (#137117)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137117
Approved by: https://github.com/yanboliang, https://github.com/williamwen42
ghstack dependencies: #137114, #137115, #137116
2024-10-07 18:55:26 +00:00
Salman Mohammadi
48c18ff850 [dynamo] Added support for tensor's is_inference method (#136450)
Fixes #135439

This PR adds support for the `is_inference` method on torch tensors which successfully compiles the following example fn without graph breaks:
```python
def fn_simple(x):
    if x.is_inference():
        return x.sum()
    else:
        return x.min()
```

I've also tried to add guards on the tensor to guard against  `is_inference`. I wasn't 100% sure where these should go so please don't hesitate to correct me.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136450
Approved by: https://github.com/ezyang
2024-09-27 09:15:07 +00:00
PyTorch MergeBot
9223c16208 Revert "Fix constant propagation in builtins and UserClasses (#131354)"
This reverts commit dd4a51b39a.

Reverted https://github.com/pytorch/pytorch/pull/131354 on behalf of https://github.com/atalman due to Breaks torchrec tests ([comment](https://github.com/pytorch/pytorch/pull/131354#issuecomment-2375417145))
2024-09-25 23:01:03 +00:00
Tom Ritchford
dd4a51b39a Fix constant propagation in builtins and UserClasses (#131354)
* Fixes https://github.com/pytorch/pytorch/issues/118675
* Replaces https://github.com/pytorch/pytorch/pull/118994

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131354
Approved by: https://github.com/jansel, https://github.com/anijain2305
2024-09-25 13:03:40 +00:00
Will Feng
a815611db9 [Traceable FSDP2][Partitioner] Must save AC output if output has a backward hook (#135727)
If node is AC region output and has a backward hook on it, we intentionally choose to save it.
This is to work around circular dependencies in Traceable FSDP2+AC.
Example:
```
out = fully_shard(utils.checkpoint(module))(x)
norm_out = layer_norm(out)
```
and there is a circular dependency:
1. In backward, grad_input of layer_norm aka. `out_grad` is actually dependent on `out`.
2. `out` depends on `out`'s backward hook created by FSDP2 (which does all-gather for `module` weights) in order to be recomputed.
3. `out`'s FSDP2 backward hook, as is the case for all eager backward hooks, depends on `out_grad`  -> circular dependency with (1)!

Solution: check whether `out` has a backward hook, and if so, intentionally save `out` in forward graph outputs. With this, we can break the above circular dependency.

----

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135727
Approved by: https://github.com/Chillee
2024-09-14 08:45:58 +00:00
Tom Ritchford
2c99f17a32 Implement VariableTracker.python_type() (#134215)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134215
Approved by: https://github.com/amjames, https://github.com/jansel
2024-09-05 16:35:47 +00:00
Michael Lazos
4e71418566 [dynamo] rewrite addcmul_ to remove graph break (#134168)
Context: Adding support for the beta parameters to be tensors

Details: Similarly to the previous two PRs addcmul_ is used with the tensor betas as the value argument. When this occurs, an item() call is invoked in the aten op. To avoid this graph break, addcmul_ is decomposed into its constrituent ops to avoid this.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134168
Approved by: https://github.com/anijain2305
ghstack dependencies: #134166, #134167
2024-08-31 10:24:39 +00:00
Oguz Ulgen
6e79932543 Add basic mypy annotations to dynamo (#132415)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132415
Approved by: https://github.com/XuehaiPan, https://github.com/jamesjwu
2024-08-04 18:43:36 +00:00
PyTorch MergeBot
3558a8cf4a Revert "Add basic mypy annotations to dynamo (#132415)"
This reverts commit 71e22e0959.

Reverted https://github.com/pytorch/pytorch/pull/132415 on behalf of https://github.com/ZainRizvi due to Sorry, this PR has entered a weird state in the diff train. Trying to revert it to skip it, and then we can try relanding it ([comment](https://github.com/pytorch/pytorch/pull/132415#issuecomment-2267631785))
2024-08-04 18:39:29 +00:00
Chen Haifeng
50ed6ce277 Support built-in id function for TensorVariable on parameters (#130100)
Fixes #130087

This patch tries to provide a built-in id function implementation for TensorVariable when the id function is called on tensors like module parameters. The id function call on intermediate tensors is not supported.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130100
Approved by: https://github.com/anijain2305
2024-08-02 01:19:25 +00:00
Oguz Ulgen
71e22e0959 Add basic mypy annotations to dynamo (#132415)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132415
Approved by: https://github.com/XuehaiPan, https://github.com/jamesjwu
2024-08-01 20:14:25 +00:00
Yiming Zhou
ee09d066d3 [dynamo] Add line number to _warn_capture_scalar_outputs() (#132333)
Fixes #127667.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132333
Approved by: https://github.com/anijain2305
2024-08-01 16:11:21 +00:00
Xuehai Pan
e74ba1b34a [BE][Easy][15/19] enforce style for empty lines in import segments in torch/_d*/ (#129767)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129767
Approved by: https://github.com/anijain2305
2024-07-31 21:18:11 +00:00
Brian Hirsh
8bb9aa93a7 dynamo: mutations on .data should be invisible to autograd (#131403)
Fixes https://github.com/pytorch/pytorch/issues/121353

our handle for `.data` in dynamo today basically just converts `y = x.data` into `y = x.detach()`. The semantics of these two ops are not quite the same, because:

(1) any future mutations on `x.data` will be fully ignored by autograd
(2) any mutations on `x.detach()` will bump x's version counter

the linked model does a .data mutation that is hidden from autograd in eager, but ends up erroring during AOTDispatcher tracing.

I updated dynamo's handling so that:

(1) when dynamo sees a call to `getattr(tensor, "data")` and calls `.detach()` we set a flag on the returned `TensorVariable` indicating it came from `.data`

(2) on any tensor method that we call with an input `TensorVariable` with this flag turned on, we proxy autograd's `preserve_version_counter` logic into the graph, to properly reset the VC after the op is run.

One thing to note is that I don't actually do this on every op that we pass the tensor to: I only do it for tensor methods that appear to be mutations (by checking for a trailing underscore). My thought was that:

(1) I didn't want to do this for **every** op that you pass `y` into, since that will e.g. triple the number of nodes in the graph, and could cause compile time regressions if you use .data

(2) this situation is pretty rare in general, and I'm hoping that "tensor method mutations" cover most reasonable mutation cases. If we manage to miss a case, you will get a loud error during tracing anyway, so there is not a safety issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131403
Approved by: https://github.com/anijain2305, https://github.com/zou3519
2024-07-26 14:22:20 +00:00