Commit Graph

817 Commits

Author SHA1 Message Date
Aaron Gokaslan
bfee141666 [BE]: Apply ruff PERF403 to use dict comprehensions more often (#149257)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149257
Approved by: https://github.com/jansel
2025-03-16 23:52:58 +00:00
Brian Hirsh
3646d4dbc8 [partitioner] always ban compiler-driven recompute of collectives by default (#147561)
This should fix the hang in https://fb.workplace.com/groups/1075192433118967/permalink/1603268720311333/

The argument here is that:

(1) in general, it is not safe for the partitioner to sometimes choose to recompute collectives in the backward. Why? If we are running a distributed job, where many ranks are compiling at the same time, we need every rank to make a consistent decision about which collectives are recomputed for backward. If we let each compiler instance make its own choice without any cross-rank communication, they can make different choices and cause NCCL hangs (see the link above)

(2) later on, we'll want an `spmd_mode` flag that causes the compiler to issue collectives and communicate info across ranks. Once we have such a config, then turning it on should make it safe for the partitioner to potentially choose to recompute collectives (and agree on the binary "recompute-or-save" choice across all ranks)

(3) even without an `spmd_mode`, users can override this choice by using `torch.utils.checkpoint()` in their user code. User checkpointing generally always overrides the partitioner, and this should be safe because we expect the user to apply checkpointing consistently across ranks

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147561
Approved by: https://github.com/zou3519
2025-03-13 03:36:13 +00:00
Brian Hirsh
621dadd4ca partitioner: when materializing unbacked tensor intermediates, apply hint to symbol, not expr (#144097)
Fixes https://github.com/pytorch/pytorch/issues/144095

open to suggestions: the `hint_int(..., fallback=...)` API feels like a bit of a footgun, because:

(1) we use the same guess for every unbacked symint (both symbols, and compound expressions)
(2) the user may have established some relationship between some unbacked symints that we are not taking into account.

I'm not sure how real of an issue (2) is - is it common to e.g. generate two unbacked symints, and then add a runtime assert that they are unequal?

Instead I did something simpler that's just enough to fix the linked issue: if we have a sympy expression containing an unbacked symbol (e.g. `u0 + 1`), then the partitioner will now fill in the symbol with our guess instead of the expression (plugging in `u0=4096` gets us 4097). This was important for an internal custom op, that had some logic like this:
```
def custom_op(x: [u0], y: [u0 + 1]):
    assert x.shape[0] = y.shape[0] - 1
    ...
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144097
Approved by: https://github.com/laithsakka
2025-03-11 02:11:57 +00:00
Simon Fan
666508eb17 [aot cache][ca] remove restriction on caching ca's aot inference graph (#148491)
but still can't cache CA's aot inference graph yet: the CA functional ops aren't serializable

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148491
Approved by: https://github.com/jamesjwu
ghstack dependencies: #148381
2025-03-08 06:08:26 +00:00
Simon Fan
c16cd25cf5 [ca] remove compiled_autograd_tracing (#148381)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148381
Approved by: https://github.com/jansel
2025-03-08 06:08:26 +00:00
Luca Wehrstedt
f80aad62fa Improve Pareto frontier plot for AutoAC (#148678)
This was added in https://github.com/pytorch/pytorch/pull/126320. It's a very nice feature, which can be used to predict memory usage for different budget values.

However, it had some limitations, notably in terms of resolution (it only sampled 21 points across the whole range thus missed many threshold values) and in distributed settings.

Here I fix those by using recursive binary searches to identify all thresholds (up to a resolution of 1e-3, which can be made configurable) and output them in SVG (to be able to discern different points), plus I add the rank to the filename and store it in a user-define directory.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148678
Approved by: https://github.com/Chillee, https://github.com/fmassa
2025-03-07 13:22:29 +00:00
Animesh Jain
d43c6f0033 [invoke_subgraph] Run joint passes on the hop graphs (#139325)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139325
Approved by: https://github.com/bdhirsh, https://github.com/zou3519
ghstack dependencies: #147559
2025-03-03 23:38:14 +00:00
Animesh Jain
eb9c127341 [dynamo][optimizers] Install ID_GUARDED tensors into the Fx graph (#147824)
Earlier, with inline flag we were lifting id-guarded tensors to the inputs to the Fx graph. But this offers no benefit. Main idea behind lifting parameters as inputs was to reuse the compilation units across many instances of the nn-module. However, if we are guarding on the `id`, we are explicitly specializing the compiled artifact to the parameter.

This PR installs the parameters back into the graph. The benefit is removal of all pre-graph bytecode to extract the id-guarded tensors from locals/globals. This increases speedup from 1.67x to 1.75x for an internal model that has large number of optimizer parameters.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147824
Approved by: https://github.com/jansel

Co-authored-by: Jason Ansel <jansel@meta.com>
2025-02-28 03:22:11 +00:00
eellison
481a57bc37 Support torch.compile rng selective activation checkpointing with cudagraph (#146878)
TODO:
- [x]  Add handling for when forward is invoked multiple times without invoking backward, so that the fwd/backward states are out of sync
- [x] Update rng state initialization to take from correct device
- [x]  Tests
- [x] handling of retain_graph
- [x] respect fallback random

Fix for https://github.com/pytorch/pytorch/issues/130123.

Updates the aot_eager and cudagraph compilation of `run_and_save_rng_state` to use the new mechanism added by https://github.com/pytorch/pytorch/pull/114068 for CUDAGraph safe rng states.

We have a pair of rng states for the fwd and backward respectively. In both forward and backward the rng op will get run with `graphsafe_run_with_rng_state` which takes in RNG state and it hooks onto the current RNG generator before running the operator. The rng states for fwd/backward are initialized with the same value. We ensure that for any given run of the forward, the corresponding backward run will have the same rng states for the op as was observed in the forward.

```
 ===== Forward graph 1 =====
 /data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", fwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = fwd_rng_state_0);  fwd_rng_state_0 = None
        ...

 ===== Backward graph 1 =====
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", tangents_1: "f32[4, 4][4, 1]cuda:0", bwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = bwd_rng_state_0);  bwd_rng_state_0 = None
```

There is some extra complication when a user either calls backward with retain_graph, or calls the backward in a different order as they called the forward. If a user has state fwd_rng_state0, bwd_rng_state0 and calls:
- fwd0: fwd_rng_state0 -> fwd_rng_state1
- fwd1: fwd_rng_state1 -> fwd_rng_state2
- bwd1
- bwd0

Then naively, when bwd1 is invoked the bwd rng states would not be equal to the same states that were observed in fwd1. I added handling of this in the aot runtime wrappers to detect pending backward invocations, and the current position of the bwd rng states, and to update when necesssary.

Other notes:

Because nodes which appear later in the forward appear earlier in the backward, we need a separate rng state for each operator. If we reused the rng across ops, the forward and backward would be run with different rng states. I.e., not applied in the same order.

Questions for reviewers:

This does change numerics, bc the rng of the op is now taken from the input rng state instead of whatever the rng would be midway through running the graph. Technically, we only need this for cuda graph. But, I'd prefer to not have a rng divergence just for cudagraph. I am making it respect `fallback_random`.

Edit: decided to apply to non cudagraphs as well, so long as fallback_random is not set

I'm initializing the rng states by cloning the current state. If you had something like 5 different rands in the model with the same shape, theyd all get the same value. This doesn't seem great. I could use some other initialization scheme like taking seed from graph position, or etc etc. Not sure. Let me know thoughts.

Edit: updated to be taken from randint()

Update: initializing rng states from torch.randint..

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146878
Approved by: https://github.com/anijain2305, https://github.com/bdhirsh
2025-02-28 00:47:03 +00:00
IvanKobzarev
7ae0e0b2ea [aotd] Log torch._functorch.config in tlparse (#147883)
Adding torch._functorch.config to tlparse for better debugability.
E.g. https://github.com/pytorch/pytorch/pull/147638 happened only with `torch._functorch.config.view_replay_for_aliased_outputs=False` which is True by defautl

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147883
Approved by: https://github.com/bdhirsh, https://github.com/jamesjwu
2025-02-27 11:22:45 +00:00
PyTorch MergeBot
17358ce778 Revert "Support torch.compile rng selective activation checkpointing with cudagraph (#146878)"
This reverts commit ad0c879e22.

Reverted https://github.com/pytorch/pytorch/pull/146878 on behalf of https://github.com/wdvr due to lint failure ([comment](https://github.com/pytorch/pytorch/pull/146878#issuecomment-2686767956))
2025-02-27 03:36:16 +00:00
eellison
ad0c879e22 Support torch.compile rng selective activation checkpointing with cudagraph (#146878)
TODO:
- [x]  Add handling for when forward is invoked multiple times without invoking backward, so that the fwd/backward states are out of sync
- [x] Update rng state initialization to take from correct device
- [x]  Tests
- [x] handling of retain_graph
- [x] respect fallback random

Fix for https://github.com/pytorch/pytorch/issues/130123.

Updates the aot_eager and cudagraph compilation of `run_and_save_rng_state` to use the new mechanism added by https://github.com/pytorch/pytorch/pull/114068 for CUDAGraph safe rng states.

We have a pair of rng states for the fwd and backward respectively. In both forward and backward the rng op will get run with `graphsafe_run_with_rng_state` which takes in RNG state and it hooks onto the current RNG generator before running the operator. The rng states for fwd/backward are initialized with the same value. We ensure that for any given run of the forward, the corresponding backward run will have the same rng states for the op as was observed in the forward.

```
 ===== Forward graph 1 =====
 /data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", fwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = fwd_rng_state_0);  fwd_rng_state_0 = None
        ...

 ===== Backward graph 1 =====
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", tangents_1: "f32[4, 4][4, 1]cuda:0", bwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = bwd_rng_state_0);  bwd_rng_state_0 = None
```

There is some extra complication when a user either calls backward with retain_graph, or calls the backward in a different order as they called the forward. If a user has state fwd_rng_state0, bwd_rng_state0 and calls:
- fwd0: fwd_rng_state0 -> fwd_rng_state1
- fwd1: fwd_rng_state1 -> fwd_rng_state2
- bwd1
- bwd0

Then naively, when bwd1 is invoked the bwd rng states would not be equal to the same states that were observed in fwd1. I added handling of this in the aot runtime wrappers to detect pending backward invocations, and the current position of the bwd rng states, and to update when necesssary.

Other notes:

Because nodes which appear later in the forward appear earlier in the backward, we need a separate rng state for each operator. If we reused the rng across ops, the forward and backward would be run with different rng states. I.e., not applied in the same order.

Questions for reviewers:

This does change numerics, bc the rng of the op is now taken from the input rng state instead of whatever the rng would be midway through running the graph. Technically, we only need this for cuda graph. But, I'd prefer to not have a rng divergence just for cudagraph. I am making it respect `fallback_random`.

Edit: decided to apply to non cudagraphs as well, so long as fallback_random is not set

I'm initializing the rng states by cloning the current state. If you had something like 5 different rands in the model with the same shape, theyd all get the same value. This doesn't seem great. I could use some other initialization scheme like taking seed from graph position, or etc etc. Not sure. Let me know thoughts.

Edit: updated to be taken from randint()

Update: initializing rng states from torch.randint..

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146878
Approved by: https://github.com/anijain2305, https://github.com/bdhirsh
2025-02-27 02:08:29 +00:00
IvanKobzarev
8594856651 [aotd] Alias of intermediate unwrap TensorAlias (#147638)
Bug was reported by internal user.

AOTD classified outputs that are aliases of intermediates of the graph in different categories.

...
- output is alias of intermediate which base is already output
- output is alias of intermediate which base is not in output

If we look at the fn:
```
def fn(x):
    ix = x + 1
    a = ix.transpose(0, 1)
    return a.detach(), a
```

output 0: detach view of alias a, where a is already output
output 1: alias of intermediate ix, then additional output ix will be added internally

output 0 base is TensorAlias(a) in this case, but could be Tensor.
Adding runtime unwrapping solves this problem.

Alternatively we should track base of a.detach() all the way to ix, in that case the base will be always a Tensor, not TensorAlias.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147638
Approved by: https://github.com/bdhirsh
2025-02-26 19:42:21 +00:00
Brian Hirsh
89b9c12de8 remove prints from partitioner (#147749)
See c57894cd74..22d8f9a657 (r1968015955)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147749
Approved by: https://github.com/Skylion007, https://github.com/laithsakka
2025-02-24 21:03:45 +00:00
Aaron Orenstein
db4ce78d46 PEP585: More UP006 fixes (#146392)
This should be the final PR before we can enable RUFF UP006.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146392
Approved by: https://github.com/justinchuby, https://github.com/albanD, https://github.com/Skylion007
2025-02-20 06:18:13 +00:00
Laith Sakka
454fbd5bbe realize stride symbols in estimate_runtime (#146752)
Unfortuanlty could not create a local repo, or unit test.
fix https://github.com/pytorch/pytorch/issues/146686

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146752
Approved by: https://github.com/bobrenjc93, https://github.com/bdhirsh
2025-02-19 06:02:49 +00:00
Basil Wong
05001f0459 Add Structured Tracing for Traced Graph Edge Details for AC Debugging (#146634)
Summary:
Updating the structured trace infrastructure so that we are able to output to Zoomer and have an E2E solution.

Context Doc: https://docs.google.com/document/d/1T6omIBEWVhbOiwDLSLffgQwjxiT2rQv8QvvQwXkw4fY/edit?usp=sharing

Test Plan:
### Testing Structured Log + tlparse locally

Command:
```
TORCH_TRACE=/data/users/basilwong/fbsource/fbcode/log_torch_trace buck2 run mode/opt //aps_models/ads/icvr:icvr_launcher -- mode=local_fb_fm_v4 launcher.num_workers=2
```

Torch Trace Logs (local then sent to paste): P1686419449
```
cat log_torch_trace/dedicated_log_torch_trace_rank_0_2lg012xo.log | pastry
P1686419449
```

tlparse output: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpyiv5wj/rank_1/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=100

tlparse graph edge details output: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpyiv5wj/rank_1/9_0_0/joint_graph_information_397.txt?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=100

Differential Revision: D61557220

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146634
Approved by: https://github.com/jansel, https://github.com/Yuzhen11
2025-02-14 02:04:26 +00:00
Brian Hirsh
447a142de2 support input mutations on tangents in compile (#141131)
Fixes https://github.com/pytorch/pytorch/issues/141111. We previously supported mutations on saved activations that happened in the backward. This PR extends the support to tangents

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141131
Approved by: https://github.com/zou3519
2025-02-13 17:48:56 +00:00
Oguz Ulgen
076215944a Turn on autograd local caches in fbcode (#146996)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146996
Approved by: https://github.com/jamesjwu
2025-02-12 23:04:39 +00:00
Simon Fan
387c993c3b [ca] remove private API: _compiled_autograd_should_lift (#146720)
Since the functional autograd + compiled autograd migration, we don't trace into nodes anymore, and everything is lifted. We can't support this flag which tries to inline make_fx style in CA initial pass. There's no more usage internally.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146720
Approved by: https://github.com/zou3519
2025-02-10 04:29:57 +00:00
rzou
15b1ac3e86 Add torch.func.debug_unwrap (#146528)
Use it to unwrap any functorch-wrapped tensor. I don't recommend using
the output in a program since it breaks the semantics of the transforms,
but it seems useful for debugging.

I will note that some people have wanted to get intermediate values out
of an e.g. grad transform, so this might be a way to do that...

Test Plan:
- tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146528
Approved by: https://github.com/Chillee
2025-02-06 18:48:09 +00:00
Aaron Gokaslan
292af3cc89 [BE][Ez]: ISC001 Auto concatenate implicit one line strings (#146408)
Apply ruff rule about implicit string concatenation, this autofixes strings that are all the same type and on the same line. These lines are broken up likely as the result of autoformatters in the past. All fixes are automated using the autofixes in ISC001.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146408
Approved by: https://github.com/justinchuby, https://github.com/janeyx99
2025-02-04 19:07:04 +00:00
Sam Larsen
23fffb54d5 Use OrderedSet in _functorch/partitioners (#146102)
In an attempt to make partitioning more deterministic, change all sets in partitioners.py to OrderedSets. Note that this change does not fix the non-determinism we're seeing in the internal model. But let's at least eliminate this potential source of non-determinism before investigating any changes to the mincut approach?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146102
Approved by: https://github.com/oulgen
2025-02-04 17:43:07 +00:00
Animesh Jain
487400f47f [dynamo] Support functools.partial variables through inspect.signature (#146339)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146339
Approved by: https://github.com/jansel
ghstack dependencies: #146322, #146116
2025-02-04 04:39:39 +00:00
Tugsbayasgalan Manlaibaatar
041e08f9dc Add buffers to parameterizaiton rule (#145991)
Differential Revision: [D68959513](https://our.internmc.facebook.com/intern/diff/D68959513)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145991
Approved by: https://github.com/bdhirsh
2025-02-03 16:49:03 +00:00
Aaron Orenstein
ccbbc88bbb Turn on mypy for _dynamo/variables/builtin.py (#145552)
The fact that mypy errors were ignored was hiding several bugs in builtin.py (for example the previous diff's incorrect override and use of `call_getattr`)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145552
Approved by: https://github.com/anijain2305, https://github.com/Skylion007
ghstack dependencies: #145551
2025-01-30 22:21:32 +00:00
James Wu
d0aa1386b8 Disable AOTAutogradCache for triton version < 3.2 (#145937)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145937
Approved by: https://github.com/bdhirsh
2025-01-29 21:32:16 +00:00
Brian Hirsh
ed141d7d1a dont assign a size to _assert_scalar in partitioner (#143877)
Fixes https://github.com/pytorch/pytorch/issues/143876

Open to other suggestions - we have an invariant that all nodes in our ATen graphs should have a `meta['val']` field, but I don't think this is actually true in all cases, so I just hardcoded the invariant to ignore `_assert_scalar()` (which is a "special" op used in dynamic shapes for runtime asserts, and doesn't have a meta['val'] field)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143877
Approved by: https://github.com/zou3519
2025-01-29 16:21:37 +00:00
Brian Hirsh
7ca156f0ee partitioner: avoid inserting duplicates into heap (#145082)
Fixes https://github.com/pytorch/pytorch/issues/145081

This looks like it was a source of quadratic compile times in the torchtitan CP graphs. There's some code in the partitioner that iteratively adds users of a node to a heap, and pops the earliest user. If you have long parallel chains of fusible ops that all eventually feed into some shared ops, then this can result in:
(1) a node getting added to the heap many times
(2) each time we pop that node, we add (duplicates of) each of that node users to the heap
(3) repeat with each user

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145082
Approved by: https://github.com/xmfan
2025-01-28 23:44:45 +00:00
James Wu
d9ffa5da65 Log info for AOTAutogradCache bypasses instead of warning (#145768)
Fixes #145767

FxGraphCache also logs to info instead of warning so lets do that

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145768
Approved by: https://github.com/eellison, https://github.com/bdhirsh
2025-01-28 19:25:36 +00:00
James Wu
7c1fc0a047 Log cache state for AOTAutograd in title of file (#145715)
Differential Revision: [D68692755](https://our.internmc.facebook.com/intern/diff/D68692755/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145715
Approved by: https://github.com/bobrenjc93
2025-01-28 02:14:18 +00:00
rzou
ea141d8134 functional compiled autograd (#144707)
This PR squashes together the following commits:

https://github.com/pytorch/pytorch/pull/144115
https://github.com/pytorch/pytorch/pull/143417
https://github.com/pytorch/pytorch/pull/143405
https://github.com/pytorch/pytorch/pull/143387
https://github.com/pytorch/pytorch/pull/143304
https://github.com/pytorch/pytorch/pull/143296

This is a refactor of compiled autograd to use "functional autograd". The end goal is that it gets compiled autograd's initial capture to stop specializing on Tensor metadata, therefore allowing compiled autograd to better handle Tensor subclasses.

For more information, please read the commit messages for each PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144707
Approved by: https://github.com/bdhirsh, https://github.com/xmfan, https://github.com/jansel
2025-01-27 05:20:56 +00:00
Edward Z. Yang
90448f0128 Output of nonzero is transposed, fix fake tensor (#144695)
Needs this companion executorch PR: https://github.com/pytorch/executorch/pull/7657

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144695
Approved by: https://github.com/bobrenjc93, https://github.com/albanD
2025-01-26 01:07:22 +00:00
PyTorch MergeBot
16c4f8c395 Revert "[compiled autograd] Always proxy autograd.Function nodes; handle AOT backwards (#143405)"
This reverts commit ec820fe57c.

Reverted https://github.com/pytorch/pytorch/pull/143405 on behalf of https://github.com/izaitsevfb due to breaking internal tests T213390054 ([comment](https://github.com/pytorch/pytorch/pull/143296#issuecomment-2611224926))
2025-01-23 23:34:13 +00:00
PyTorch MergeBot
ab082863a1 Revert "[compiled autograd] support Tensor Subclasses in AOTBackward (#144115)"
This reverts commit 082c28c3c6.

Reverted https://github.com/pytorch/pytorch/pull/144115 on behalf of https://github.com/izaitsevfb due to breaking internal tests T213390054 ([comment](https://github.com/pytorch/pytorch/pull/143296#issuecomment-2611224926))
2025-01-23 23:34:12 +00:00
Aaron Gokaslan
5ebca3015d [BE]: Simplify set add with set update (#145152)
Simplifies the set update slightly to be more readable and efficient.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145152
Approved by: https://github.com/XuehaiPan, https://github.com/albanD

Co-authored-by: Xuehai Pan <XuehaiPan@outlook.com>
2025-01-23 20:18:13 +00:00
Li Yu (ads)
e6a84be3d3 [PyTorch] Add backend aot_eager_decomp_partition_with_mode (#143250)
Summary:
## Why
To make it possible to run torch dispatch mode inside compiled modules. This is to enable running MemoryTrackerMode (in next diff) to collect memory usage of compiled modules.

## What
Add a backend aot_eager_decomp_partition_with_mode.
Add an enable_log to the backend to control the compilation logging (which can be very verbose and slow the run of mode)

Test Plan:
unittest

E2e tested in the next diff which shows the memory read from the mode passed to this backend is very close to the actual job's memory snapshot.

Differential Revision: D67227144

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143250
Approved by: https://github.com/bdhirsh
2025-01-22 23:20:59 +00:00
PyTorch MergeBot
f0a210bf5d Revert "Output of nonzero is transposed, fix fake tensor (#144695)"
This reverts commit 693d8c7e94.

Reverted https://github.com/pytorch/pytorch/pull/144695 on behalf of https://github.com/izaitsevfb due to breaking internal tests, see D68461259 ([comment](https://github.com/pytorch/pytorch/pull/144695#issuecomment-2608443589))
2025-01-22 23:04:50 +00:00
PyTorch MergeBot
6e53588789 Revert "[BE]: Simplify set add with set update (#145152)"
This reverts commit 0cb9b2284a.

Reverted https://github.com/pytorch/pytorch/pull/145152 on behalf of https://github.com/davidberard98 due to land race with https://github.com/pytorch/pytorch/pull/145165 broke lint ([comment](https://github.com/pytorch/pytorch/pull/145152#issuecomment-2608378172))
2025-01-22 22:14:26 +00:00
rzou
082c28c3c6 [compiled autograd] support Tensor Subclasses in AOTBackward (#144115)
Compiled autograd's initial trace traces through the AOTBackward
epilogue. The Tensor Subclass code is not traceable. This PR changes it
so that when we see Tensor Subclass constructors, we proxy nodes for
their construction into the graph.

Test Plan:
- New basic test with TwoTensor
- Existing tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144115
Approved by: https://github.com/jansel, https://github.com/xmfan, https://github.com/bdhirsh
ghstack dependencies: #143296, #143304, #143387, #143405, #143417
2025-01-22 21:51:07 +00:00
rzou
ec820fe57c [compiled autograd] Always proxy autograd.Function nodes; handle AOT backwards (#143405)
We will always proxy autograd.Function nodes in compiled autograd's
initial graph capture (previously there was an
option to proxy vs trace into the autograd.Function)

We have some requirements for the AOTBackward. Compiled Autograd runs
accumulate grad reordering passes on the AOTBackward graph directly
after the initial graph capture, so we can't just proxy a single node for it.

Instead, we:
- proxy the AOTBackward prologue function into the CA graph
- copy-paste the AOTBackward graph into the CA graph
- trace directly through the epilogue (the traced nodes go into the CA
  graph).

Tracing through the epilogue is safe (assuming no Tensor subclasses)
because the only thing the epilogue does is drop some outputs. The
Tensor subclass situation was already broken so this doesn't regress
anything but this PR sets it up to be fixed (in a followup, where we
will proxy "make_subclass" calls into the graph from the epilogue).

Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143405
Approved by: https://github.com/jansel, https://github.com/xmfan
ghstack dependencies: #143296, #143304, #143387
2025-01-22 21:50:56 +00:00
Aaron Gokaslan
0cb9b2284a [BE]: Simplify set add with set update (#145152)
Simplifies the set update slightly to be more readable and efficient.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145152
Approved by: https://github.com/XuehaiPan, https://github.com/albanD

Co-authored-by: Xuehai Pan <XuehaiPan@outlook.com>
2025-01-22 21:31:13 +00:00
Edward Z. Yang
693d8c7e94 Output of nonzero is transposed, fix fake tensor (#144695)
Needs this companion executorch PR: https://github.com/pytorch/executorch/pull/7657

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144695
Approved by: https://github.com/bobrenjc93, https://github.com/albanD
2025-01-21 20:50:09 +00:00
Aaron Orenstein
78bff1e8c1 PEP585 update - torch/_functorch (#145139)
See #145101 for details.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145139
Approved by: https://github.com/bobrenjc93
2025-01-19 07:06:10 +00:00
PyTorch MergeBot
6c713ccb5e Revert "Make functionalization ViewMeta serializable with pickle. (#143712)"
This reverts commit b8abdaa286.

Reverted https://github.com/pytorch/pytorch/pull/143712 on behalf of https://github.com/kit1980 due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/143712#issuecomment-2597205261))
2025-01-17 00:52:50 +00:00
Yukio Siraichi
b8abdaa286 Make functionalization ViewMeta serializable with pickle. (#143712)
Fix: #141974

This PR makes `ViewMeta` sequence, present in functional tensors,
serializable with pickle. In order to accomplish that, it makes
`ViewMeta` an abstract class with overridable `forward` and `reverse`
functions. In this context, each operation that once instanciated
`ViewMeta`, should now create a new specialized class that inherits from
`ViewMeta. Therefore, this PR also uses codegen for creating these
specializations.

In summary, these are the changes this PR introduces:

- `ViewMeta` is turned into an abstract class (see
  _FunctionalStorageImpl.cpp_). `forward` and `reverse` are pure virtual
  functions that need to be implemented. `to_out_index` should be
  implemented by operations that might return more than 1 output.

- New `ViewMeta` specializations for `resize_` and `_unsafe_view` are
  created (see _FunctionalizeFallbackKernel.h_).

- New templates _ViewMetaClasses.{cpp,h}_ are created. They hold the
  declaration and definition of the `ViewMeta` specializations, which
  are automatically generated in the ATen codegen (see _gen.py_).

- New `_functionalization` Python sub-module is created (see
  _Module.cpp_). It serves as namespace for the `ViewMeta`
  specializations and `InverseReturnMode` enum.

- New template _ViewMetaClassesPythonBinding.cpp_ is created. It holds
  the automatically generated Python bindings for the `ViewMeta`
  specialization, which are generated in the torch codegen (see
  _generate_code.py_).

Note that this PR makes use of codegen at 2 different moments:

- ATen codegen (_gen.py_): generates the `ViewMeta` specialized classes.
- Torch codegen (_generate_code.py_): generated the Python bindings for
  them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143712
Approved by: https://github.com/bdhirsh
2025-01-16 19:41:41 +00:00
James Wu
7d71ddbe5d Add non_c_binding torch functions to allowlist for AOTAutogradCache, confirm no special handlers for them (#144802)
Differential Revision: [D68173093](https://our.internmc.facebook.com/intern/diff/D68173093/)

This diff allows any function in torch_non_c_binding_in_graph_functions to be safe to cache. These functions should be safe to cache because they are part of the torch API, and do not save global state (or if they do, dynamo creates unique guards around the constants they return).
A function that's allowed in a dynamo graph is safe to cache for AOTAutograd purposes as long as:
- It's functional (i.e. does not access global state);
- or its value is constant folded away (and guarded against by dynamo)

The tricky cases are functions that dynamo uses special handlers to track. These special handlers can sometimes close over stuff that's safe for dynamo locally, but isn't encoded anywhere when cached across processes. An example of this is `DTensor.from_local`, where various DeviceMesh information doesn't change in the same dynamo process, but can change across multiple processes. The handler for `DTensor.from_local` closes over these and dynamo creates a proxy for the function call. This is not safe to cache.

That said, most special handlers are in fact functional and safe. So I add a unit test to test_trace_rules.py that confirms that any function with special handlers in dynamo added to this list needs to be audited to be safe to cache.

The list of safe handlers there either:
- Don't access global state;
- Guard on global state; or
- Always returns a constant that never changes

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144802
Approved by: https://github.com/bdhirsh
2025-01-15 05:41:36 +00:00
Aaron Orenstein
8ad37ed710 Stop ignoring mypy errors in torch/testing/_internal/common_utils.py (#144483)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144483
Approved by: https://github.com/Skylion007
2025-01-14 22:32:51 +00:00
PyTorch MergeBot
dfe06e555d Revert "Stop ignoring mypy errors in torch/testing/_internal/common_utils.py (#144483)"
This reverts commit dcc04e9237.

Reverted https://github.com/pytorch/pytorch/pull/144483 on behalf of https://github.com/kit1980 due to Need to revert in order to revert https://github.com/pytorch/pytorch/pull/144441 ([comment](https://github.com/pytorch/pytorch/pull/144483#issuecomment-2588515018))
2025-01-14 00:46:48 +00:00
Aaron Orenstein
dcc04e9237 Stop ignoring mypy errors in torch/testing/_internal/common_utils.py (#144483)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144483
Approved by: https://github.com/Skylion007
2025-01-13 23:19:44 +00:00