Commit Graph

92 Commits

Author SHA1 Message Date
linhaifeng
369f2d6951 [3/N] fix typo in other folders (#166606)
fix typo in other folders

#166374
#166126

_typos.toml
```bash
[files]
extend-exclude = ["tools/linter/dictionary.txt"]
[default.extend-words]
nd = "nd"
arange = "arange"
Nd = "Nd"
GLOBALs = "GLOBALs"
hte = "hte"
iy = "iy"
PN = "PN"
Dout = "Dout"
optin = "optin"
gam = "gam"
PTD = "PTD"
Sur = "Sur"
nin = "nin"
tme = "tme"
inpt = "inpt"
mis = "mis"
Raison = "Raison"
ouput = "ouput"
nto = "nto"
Onwer = "Onwer"
callibrate = "callibrate"
ser = "ser"
Metdata = "Metdata"
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166606
Approved by: https://github.com/ezyang
2025-10-30 10:30:40 +00:00
Brian Hirsh
ed74dc054d add the option to disable functionalization in AOTDispatcher (#164577)
I'm cleaning this PR up as a proper way of disabling functionalization via config in AOTDispatcher. I removed the non-functionalization related changes from the original version:

(1) preventing proxy mode (and functionalization) from incorrectly decomposing CIA ops (Ed has a PR for it here: https://github.com/pytorch/pytorch/pull/164939)

(2) preventing python-dispatcher-based decomps above autograd from running. I'm not doing this for now, will likely do it in a followup

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164577
Approved by: https://github.com/ezyang
ghstack dependencies: #165372
2025-10-16 15:44:11 +00:00
Yuanyuan Chen
fbe0d20a17 [2/N] More ruff SIM fixes (#165031)
This is follow-up of #164695 to apply ruff SIM rules to more files. Most changes are about simplifying dict.get because None is already the default value.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165031
Approved by: https://github.com/mlazos
2025-10-14 14:22:54 +00:00
PyTorch MergeBot
b8be796a57 Revert "[2/N] More ruff SIM fixes (#165031)"
This reverts commit 38095fbd13.

Reverted https://github.com/pytorch/pytorch/pull/165031 on behalf of https://github.com/albanD due to One of the changed line started to fail on trunk ([comment](https://github.com/pytorch/pytorch/pull/165031#issuecomment-3390190870))
2025-10-10 13:42:14 +00:00
Yuanyuan Chen
38095fbd13 [2/N] More ruff SIM fixes (#165031)
This is follow-up of #164695 to apply ruff SIM rules to more files. Most changes are about simplifying dict.get because None is already the default value.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165031
Approved by: https://github.com/mlazos
2025-10-10 05:37:46 +00:00
Sherlock Huang
e532f62e0d Introduce joint_custom_pass callback (#164981)
```
        def joint_custom_pass(joint_gm: torch.fx.GraphModule, joint_inputs):
           # apply your pass for joint graph here

            return joint_gm

        class M(torch.nn.Module):
            def forward(self, x):
                return x.sin()

        x = torch.randn(10, requires_grad=False)
        compiled_fn = torch.compile(M(), backend="aot_eager")

        with torch._functorch.config.patch("joint_custom_pass", joint_custom_pass):
            out = compiled_fn(x)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164981
Approved by: https://github.com/ezyang, https://github.com/anijain2305
2025-10-09 04:40:54 +00:00
PyTorch MergeBot
5d7360bb03 Revert "Enable all SIM rules except disabled ones (#164645)"
This reverts commit 321e602692.

Reverted https://github.com/pytorch/pytorch/pull/164645 on behalf of https://github.com/izaitsevfb due to causes lint failures ([comment](https://github.com/pytorch/pytorch/pull/164645#issuecomment-3369274351))
2025-10-05 19:32:21 +00:00
Yuanyuan Chen
321e602692 Enable all SIM rules except disabled ones (#164645)
`SIM` rules are useful for simplifying boolean expressions and enhances code readability.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164645
Approved by: https://github.com/ezyang
2025-10-05 07:38:25 +00:00
James Wu
e156a07171 [Precompile] [RFC] Implement aot_compile_module (#162171)
This PR adds a new interface _aot_compile to `OptimizedModule`, so that the following is possible:

```
mod = SimpleLinearModule()
inputs = [
            ModelInput(
                args=(torch.randn(3, 3),),
                kwargs={},
                contexts=[torch.no_grad(), eval_mode(model)],
            ),
            ModelInput(
                args=(torch.randn(3, 3),), kwargs={}, contexts=[train_mode(model)]
            ),
        ]
        assert isinstance(model, torch._dynamo.eval_frame.OptimizedModule)
        model._aot_compile(
            inputs,
        )
```

After this PR, you can AOT precompile NanoGPT and use it to train directly. I'll share my fork of the repo to make this work.

## ModelInput
The `ModelInput` API is a work in progress; for now it represents a set of inputs and contexts to instruct the compiler to compile. Most commonly, this is "compile an eval mode with no grad, and  a training mode with grad", but also contains things like autocasting contexts, etc.

## Dispatch
Dispatching is super simple here, we just iterate through all the precompiled fullgraphs and check guards for each one until there's one htat passes. I'm a bit worried that having this in python code is going to be too expensive. The guard checks are happening in C++ anyway, though, so the only python bottlenecked step here is just the for loop, so perhaps the overhead will not be high. I'll work on measuring this, though.

## TODOs

This PR does not support `mod.compile()`, only `torch.compile(mod)`. In order to support `mod.compile()`, we'll need to update torch.nn.Module with an updated implementation — I can add that frontend later.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162171
Approved by: https://github.com/zhxchen17
2025-09-14 23:32:28 +00:00
Patrick C. Toulme
d46768db04 [MTIA] Allow users who know what they are doing to ignore all device mismatches in tracing and take a preferred device. (#159931)
Summary:
Device mismatches in tracing can most often be ignored. These are only logical mismatches not physical.

Take any intermediate computation, and that computation will not actually materialize in a compiled binary execution. So a device mismatch in the middle of the program is not real. The runtime will never materialize those tensors on CPU device during the execution, as they are temporary allocations.

If a user knows his tensors at graph input are all on the correct device, then he can ignore all tracing errors.

Users who know what they are doing should have an escape hatch to ignore any device mismatch in tracing.

Users can set
```
  torch._functorch.config.fake_tensor_prefer_device_type = 'mtia'
```
to forcefully override any mismatch and prefer the non cpu device. This unblocks vLLM graph mode for MTIA.

Test Plan:
Added two unit tests.

Rollback Plan:

Differential Revision: D79698438

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159931
Approved by: https://github.com/jansel
2025-08-07 22:37:15 +00:00
James Wu
fb462cec8d Normalize placeholder names in AOTAutogradCache (#157916)
This PR adds a pass to sanitize_gm_for_cache which normalizes all placeholder names across input dynamo graphs to AOTAutogradCache. This is safe because nothing underneath AOTAutograd uses the node names on the
original dynamo graph: AOTAutograd re-traces with its own nodes, and guards are
in terms of original sources rather than placeholder names.

Note that the dynamo output graphs traced by tlparse will not show this change because it's done before this sanitization step. The aot autograd outputs also will not change because AOTAutograd's own traced graphs don't use the original placeholders of the dynamo graph. Thus, this change is essentially a no-op from everyone's perspective except for cache key checks.

Fixes #157792

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157916
Approved by: https://github.com/zou3519
2025-07-14 17:45:11 +00:00
Xuehai Pan
7f14b42adf [BE][2/16] fix typos in torch/ (torch/_*/) (#156312)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156312
Approved by: https://github.com/albanD
2025-07-12 05:47:06 +00:00
PyTorch MergeBot
e15f4248ad Revert "[BE][2/16] fix typos in torch/ (torch/_*/) (#156312)"
This reverts commit 7a92b51196.

Reverted https://github.com/pytorch/pytorch/pull/156312 on behalf of https://github.com/XuehaiPan due to landrace ([comment](https://github.com/pytorch/pytorch/pull/156312#issuecomment-3064672250))
2025-07-12 04:40:52 +00:00
Xuehai Pan
7a92b51196 [BE][2/16] fix typos in torch/ (torch/_*/) (#156312)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156312
Approved by: https://github.com/albanD
2025-07-12 01:47:22 +00:00
Shuai Yang
6c42afe196 Introduce sync_cross_rank_decision (#156287)
Summary:
This is an improvement over `_broadcast_rank0_decision` where we uses the rank0's decision to broadcast to every rank. The issue of `_broadcast_rank0_decision` is that we observed large variance on the peak memory usage. One cause is that different ranks receive different dynamic shaped tensors and the hints of those tensors are different in different ranks. If we only rely on rank0's decision and it's unlucky to get unrepresentative hints, then the decision it makes may not be suitable for other ranks.

Here, we introduce `sync_cross_rank_decision` which comes up with the decision after comparing all ranks' local decision, it will:
1. all gather decisions from all ranks;
2. test each decision on the current rank and get its estimated memory usage;
3. all reduce estimated memory usage with ReduceOp.MAX, so that we know the maximum memory usage of each decision on all ranks;
4. pick the decision which gives us minimum maximum memory memory usage;

A graph to show more details
https://internalfb.com/excalidraw/EX484509

After applying sync_cross_rank_decision, we observed that the variance are much smaller

Rollback Plan:

Differential Revision: D76714005

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156287
Approved by: https://github.com/fmassa, https://github.com/bdhirsh
2025-07-03 23:43:53 +00:00
rzou
aa2d54148d Add AOTDispatcher config to set backward autocast behavior (#156356)
This PR adds a new config `backward_pass_autocast`, to set the backward autocast
behavior. It does not change the existing behavior.

The reason why we need this is that torch.compile acquires a forward and
backward graph at the time of the forward pass. This means that
implemented naively, if there are any context managers active outside
the call to torch.compile, the backward graph will also get the
behaviors from those context managers. This PR gives users a way to
tweak the autocast behavior of the backward pass.

Please see torch._functorch.config for the options to the
`backward_pass_autocast` config.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156356
Approved by: https://github.com/bdhirsh
ghstack dependencies: #155354
2025-06-27 14:58:58 +00:00
Xuehai Pan
162ca185ff [BE][PYFMT] migrate PYFMT for torch/_[a-h]*/ to ruff format (#144551)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144551
Approved by: https://github.com/ezyang
ghstack dependencies: #148186
2025-06-25 06:16:06 +00:00
IvanKobzarev
4439255148 [aotd] Support saved tensors hooks in aot_autograd (#150032)
https://github.com/pytorch/pytorch/issues/148222

Goal:

At the moment autograd saved tensors hooks are run in eager after compiled forward.
They are executed at the same time for all saved tensors.
Hooks can be used to reduce amout of memory used for saved tensors, doing quantization or offloading to cpu.
This is suboptimal for optimization of peak memory.
Better solution will be to put the hooks in the graph, as close as possible to the last usage of the tensor.

To get user specified autograd saved tensors hooks in the graph.

Logic:

UX:
If user specifies with torch.autograd.graph.saved_tensors_hooks(pack_gm, unpack_gm).
Where pack_gm and unpack_gm are torch.fx.GraphModule.
Then AotAutograd will retrace those graph modules, doing decompositions and functionalization in aot_autograd, inlining the result graphs in forward epilogue and backward prologue.

User may want to use control logic in the hooks, for example applying quantization only for specific dtypes and sizes.

This is also possible, user can put it into torch.fx.wrap function and use symbolic trace to make a GraphModule.

In that case AotAutograd cahing will work only in case when user explicitly set to the torch.fx.wrap call_function node "user_cache_hash" metadata.

If this metadata set - then aot_autograd cache can use saved cache artifact.
If metadata is not set - then cache is bypassed.

Dynamo:
Dynamo traces pack and unpack hooks and installs them as subgraph and explicitly adds to the output_graph. (As those subgraphs are not used and will not be copied in the result by default).

The complexity here is that at this moment we do not have example of inputs for the hooks.
We trace  pack_hook with some Tensor from the inputs.
The result subgraphs are added to the hashing of AotAutograd Cache.

In AotAutograd we retrace the graph with the true saved tensors coming from partitioner.

Backwards Compatibility:
As current hooks are executed in eager mode and not all of them will be traceable - we only try to put in the graph hooks, explicitly marked by user with annotation (@_inlineable_saved_tensors_hooks).
For other hooks or if compiled autograd is enabled - keep the same logic.

Recompilations:
Hooks are guarded with lambda guard matching function id to cause recompilation if user reruns compiled function.

Aot_autograd:
After partitioner prepared forward and backward module - we trace prepared at Dynamo graphs for pack and unpack hooks and inline them in epilogue of forward and prologue of backward. Forward outputs and backward inputs are changed, transparently for user.

We do not try to put it close the last usage etc., relying on inductor to do this optimization.

```
INFO: TRACED GRAPH
 ===== Forward graph pre saved_tensors_hooks inlining 3 =====
 /data/users/ivankobzarev/a/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1][s1, 1]cuda:0"):
         # File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6660 in simple_fn, code: x = x + 1
        add: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.add.Tensor(primals_3, 1);  primals_3 = None

         # File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6661 in simple_fn, code: x = SAF.apply(x)
        view: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.view.default(add, [primals_1, primals_2])
        return (view, add, primals_1, primals_2)

INFO: TRACED GRAPH
 ===== Backward graph pre saved_tensors_hooks inlining 3 =====
 /data/users/ivankobzarev/a/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1][s1, 1]cuda:0"):
         # File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6660 in simple_fn, code: x = x + 1
        add: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.add.Tensor(primals_3, 1);  primals_3 = None

         # File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6661 in simple_fn, code: x = SAF.apply(x)
        view: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.view.default(add, [primals_1, primals_2])
        return (view, add, primals_1, primals_2)

INFO: TRACED GRAPH
 ===== saved_tensors_pack_hook add 3 =====
 /data/users/ivankobzarev/a/pytorch/torch/fx/_lazy_graph_module.py class pack_float8(torch.nn.Module):
    def forward(self, x_1: "f32[s0, s1][s1, 1]cuda:0"):
        # No stacktrace found for following nodes
        _to_copy: "f8e4m3fn[s0, s1][s1, 1]cuda:0" = torch.ops.aten._to_copy.default(x_1, dtype = torch.float8_e4m3fn);  x_1 = None
        return (torch.float32, _to_copy)

INFO: TRACED GRAPH
 ===== saved_tensors_unpack_hook add 3 =====
 <eval_with_key>.22 from /data/users/ivankobzarev/a/pytorch/torch/fx/experimental/proxy_tensor.py:1225 in wrapped class pack_float8(torch.nn.Module):
    def forward(self, x_1: "f32[s0, s1][s1, 1]cuda:0"):
        # No stacktrace found for following nodes
        _to_copy: "f8e4m3fn[s0, s1][s1, 1]cuda:0" = torch.ops.aten._to_copy.default(x_1, dtype = torch.float8_e4m3fn);  x_1 = None
        return (torch.float32, _to_copy)

INFO: TRACED GRAPH
 ===== Forward graph 3 =====
 /data/users/ivankobzarev/a/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1][s1, 1]cuda:0"):
         # File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6660 in simple_fn, code: x = x + 1
        add: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.add.Tensor(primals_3, 1);  primals_3 = None

        # No stacktrace found for following nodes
        _to_copy: "f8e4m3fn[s0, s1][s1, 1]cuda:0" = torch.ops.aten._to_copy.default(add, dtype = torch.float8_e4m3fn)

         # File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6661 in simple_fn, code: x = SAF.apply(x)
        view: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.view.default(add, [primals_1, primals_2]);  add = None
        return (view, _to_copy, primals_1, primals_2)

INFO: TRACED GRAPH
 ===== Backward graph 3 =====
 <eval_with_key>.21 class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", add_packed_2: "f8e4m3fn[s0, s1][s1, 1]cuda:0", tangents_1: "f32[s0, s1][s1, 1]cuda:0"):
        # No stacktrace found for following nodes
        _to_copy: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten._to_copy.default(add_packed_2, dtype = torch.float32);  add_packed_2 = None

         # File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6661 in simple_fn, code: x = SAF.apply(x)
        add_7: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.add.Tensor(tangents_1, _to_copy);  tangents_1 = _to_copy = None
        return (None, None, add_7)

```

Differential Revision: [D72187044](https://our.internmc.facebook.com/intern/diff/D72187044)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150032
Approved by: https://github.com/bdhirsh
2025-05-22 14:09:38 +00:00
James Wu
c31e239910 [precompile] Add BundledAOTAutogradCacheEntry (#152840)
Finally, this PR adds BundledAOTAutogradCacheEntry. A BundledAOTAutogradCacheEntry is an AOTAutogradCacheEntry that saves the entire CompiledFxGraph directly in the entry.

This has some advantages:
- No more dependency on FxGraphCache at all
- Clearing FxGraphCache does not result in AOTAutogradCache miss
- Simpler logic, as BundledAOTAutogradCacheEntry has everything you need to load a full compiled python wrapper from a dynamo output

We plan to use BundledAOTAutogradCacheEntry for precompile. There's also a question of whether we want to use it for regular caching — the main disadvantage of this is having to save the same CompiledFxGraph twice, once in Inductor cache and once for AOTAutogradCache. With MegaCaching, this *could* be a regression in total cache size (as well as a minor cold start regression, as you have to save the same graph twice). I will import this and measure the mega cache space complexity, and if it looks good I'll enable it by default for caching as well.

On warm start, if AOTAutogradCache hits, you won't have to load inductor at all, so warm start overhead should be unaffected.

Differential Revision: [D74593304](https://our.internmc.facebook.com/intern/diff/D74593304)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152840
Approved by: https://github.com/zhxchen17
2025-05-21 18:08:42 +00:00
rzou
3d777bae10 Inductor respects exact strides on custom ops by default (#150511)
If a tag is not specified on a custom operator, then inductor will
assume that it needs exact strides.

Test Plan:
- tests + CI

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150511
Approved by: https://github.com/eellison, https://github.com/shunting314
ghstack dependencies: #148104
2025-05-03 00:02:24 +00:00
rzou
2b37a726e0 Refactor layout constraint selection logic (#148104)
This PR:

- cleans up some existing comments that don't make sense anymore
- hooks up the "custom_op_default_layout_constraint" back (that seems to
have broken)
- cleans up the "lazy registration path" which seems to never get hit
anymore
- adds dislike_padding to nodes that require exact strides

Test Plan:
- tests + CI

disable padding

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148104
Approved by: https://github.com/shunting314, https://github.com/eellison
2025-05-03 00:02:24 +00:00
Francisco Massa
89c0c3ca80 Add private config to broadcast rank0 decision from the partitioner to all ranks (#152264)
Summary: This PR adds a private configuration to the partitioner that ensures that the decision taken is the same across all ranks. This is a temporary workaround, as when size_hints are also taken into account in compiler collectives this workaround will not be needed anymore.

Test Plan:
This has been tested on some internal models, but I haven't added any tests in PyTorch (yet?)
T

Differential Revision: D73666017

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152264
Approved by: https://github.com/bdhirsh
2025-04-29 21:27:57 +00:00
Oguz Ulgen
3cf0e2d8ec Add inductor standalone_compile API (#150670)
This PR adds standalone_compile API that does precompilation via caching to support vLLM use case in the short term while we work on the longer term precompilation solution.

```
standalone_compile(gm, example_inputs, options) -> CompiledArtifact
CompiledArtifact.save(path, format: binary|unpacked = binary)
CompiledArtifact.load(path, format: binary|unpacked = binary)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150670
Approved by: https://github.com/jamesjwu, https://github.com/zou3519
2025-04-15 23:38:15 +00:00
PyTorch MergeBot
74f6bc28a7 Revert "Add inductor standalone_compile API (#150670)"
This reverts commit c9aef50898.

Reverted https://github.com/pytorch/pytorch/pull/150670 on behalf of https://github.com/Camyll due to breaking internal builds with torch module not found error ([comment](https://github.com/pytorch/pytorch/pull/150670#issuecomment-2806975267))
2025-04-15 17:35:59 +00:00
Oguz Ulgen
c9aef50898 Add inductor standalone_compile API (#150670)
This PR adds standalone_compile API that does precompilation via caching to support vLLM use case in the short term while we work on the longer term precompilation solution.

```
standalone_compile(gm, example_inputs, options) -> CompiledArtifact
CompiledArtifact.save(path, format: binary|unpacked = binary)
CompiledArtifact.load(path, format: binary|unpacked = binary)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150670
Approved by: https://github.com/jamesjwu, https://github.com/zou3519
2025-04-14 22:00:09 +00:00
PyTorch MergeBot
24b3ab9255 Revert "Add inductor standalone_compile API (#150670)"
This reverts commit bbc5fe8504.

Reverted https://github.com/pytorch/pytorch/pull/150670 on behalf of https://github.com/albanD due to Broke profiler test ([comment](https://github.com/pytorch/pytorch/pull/150670#issuecomment-2802067144))
2025-04-14 15:22:33 +00:00
Oguz Ulgen
bbc5fe8504 Add inductor standalone_compile API (#150670)
This PR adds standalone_compile API that does precompilation via caching to support vLLM use case in the short term while we work on the longer term precompilation solution.

```
standalone_compile(gm, example_inputs, options) -> CompiledArtifact
CompiledArtifact.save(path, format: binary|unpacked = binary)
CompiledArtifact.load(path, format: binary|unpacked = binary)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150670
Approved by: https://github.com/jamesjwu, https://github.com/zou3519
2025-04-14 07:07:10 +00:00
IvanKobzarev
25309a17f0 [aotd] Config to guess_tangents_stride (#150035)
Differential Revision: [D71907684](https://our.internmc.facebook.com/intern/diff/D71907684)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150035
Approved by: https://github.com/ilyas409, https://github.com/seemethere
2025-03-28 13:54:19 +00:00
James Wu
49f86a939c [AOTAutogradCache] Allow Custom Autograd functions behind a flag (#149751)
This adds a new env var and flag,

autograd_cache_allow_custom_autograd_functions, (env var: `TORCHINDUCTOR_AUTOGRAD_CACHE_ALLOW_CUSTOM_AUTOGRAD`) which allows custom autograd functions into AOTAutogradCache.

@hirsheybar and I worked together to verify that the higher order op AutogradFunctionApply is pure with respect to the dynamo input being passed in, so this *should* be safe. I'm still putting it behind a flag and turning it on slowly, first on an internal model, though. Once we verify that it is correct on the internal model we can work to enable the flag by default.

Differential Revision: [D71633184](https://our.internmc.facebook.com/intern/diff/D71633184/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149751
Approved by: https://github.com/bdhirsh, https://github.com/zou3519
2025-03-24 21:12:11 +00:00
Brian Hirsh
f06e366532 partitioner: treat inputs with static indices as free to save (#148922)
Fixes https://github.com/pytorch/pytorch/issues/141881

internal xref: https://fb.workplace.com/groups/1075192433118967/posts/1538435030128036/?comment_id=1556782068293332

I tried to make a test case out of the code linked in that github issue. The setup + bad outcome today was as follows:

(1) you have a graph where one of its inputs is a model weight

(2) in the backward, you do some downstream compute on `weight`, `tmp = f(weight)`, where (a) `tmp` is of a smaller size than `weight`, and (b) the compute is trivially fusible into other kernels (so the partitioner thinks it is "free" to recompute

(3) since `sizeof(tmp) < sizeof(weight)` and the recompute is free, the partitioner decides that it would be strictly better to save `tmp` for backward instead of weight

(4) this is bad: `weight` is a static tensor that sits in GPU memory for the duration of your entire training loop, so saving it for backward has no negative impact on peak memory.  Since we're saving `tmp` instead, we end up unnecessarily increasing peak memory. In particular - the repro involves an autograd.Function in eager that saves the weight for bw, so we end up hitting higher peak memory in compile

The fix I'm trying out in this PR is to tell the partitioner that graph inputs that we know have static addresses (aka parameters) are "free" to save.

Below is the fw/bw graph before my change, where you can see that instead of `primals_2` being saved for backward, we save `t_8` (which involves some low precision downstream compute on `primals_2`, that is only needed in the backward.

```
 ===== Forward graph 0 =====
 /data/users/hirsheybar/checkout2/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "bf16[64, 64][64, 1]cuda:0", primals_2: "bf16[64, 64][64, 1]cuda:0", primals_3: "bf16[64][1]cuda:0"):
         # File: /data/users/hirsheybar/checkout2/pytorch/test/dynamo/test_repros.py:6943 in forward, code: out = Fp8LinearFn.apply(
        abs_1: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.abs.default(primals_1)
        view: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(abs_1, [64, 1, 64]);  abs_1 = None
        amax: "bf16[64, 1][1, 1]cuda:0" = torch.ops.aten.amax.default(view, [-1]);  view = None
        abs_2: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.abs.default(primals_2)
        view_1: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(abs_2, [64, 1, 64]);  abs_2 = None
        amax_1: "bf16[64, 1][1, 1]cuda:0" = torch.ops.aten.amax.default(view_1, [-1]);  view_1 = None
        _to_copy: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten._to_copy.default(amax, dtype = torch.float32);  amax = None
        clamp: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.clamp.default(_to_copy, 1e-12);  _to_copy = None
        div: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.div.Tensor(clamp, 448.0);  clamp = None
        reciprocal: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.reciprocal.default(div)
        view_2: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(primals_1, [64, 1, 64])
        view_3: "bf16[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_2, [64, 1, 1, 64]);  view_2 = None
        slice_1: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(reciprocal, 0, 0, 9223372036854775807);  reciprocal = None
        unsqueeze: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_1, 1);  slice_1 = None
        slice_2: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze, 2, 0, 9223372036854775807);  unsqueeze = None
        unsqueeze_1: "f32[64, 1, 1, 1][1, 1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_2, 3);  slice_2 = None
        mul: "f32[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_3, unsqueeze_1);  view_3 = unsqueeze_1 = None
        view_4: "f32[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(mul, [64, 1, 64]);  mul = None
        view_5: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_4, [64, 64]);  view_4 = None
        _to_copy_1: "f8e4m3fn[64, 64][64, 1]cuda:0" = torch.ops.aten._to_copy.default(view_5, dtype = torch.float8_e4m3fn);  view_5 = None
        _to_copy_2: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten._to_copy.default(amax_1, dtype = torch.float32)
        clamp_1: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.clamp.default(_to_copy_2, 1e-12);  _to_copy_2 = None
        div_1: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.div.Tensor(clamp_1, 448.0);  clamp_1 = None
        reciprocal_1: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.reciprocal.default(div_1)
        view_6: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(primals_2, [64, 1, 64])
        view_7: "bf16[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_6, [64, 1, 1, 64]);  view_6 = None
        slice_3: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(reciprocal_1, 0, 0, 9223372036854775807);  reciprocal_1 = None
        unsqueeze_2: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_3, 1);  slice_3 = None
        slice_4: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_2, 2, 0, 9223372036854775807);  unsqueeze_2 = None
        unsqueeze_3: "f32[64, 1, 1, 1][1, 1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_4, 3);  slice_4 = None
        mul_1: "f32[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_7, unsqueeze_3);  view_7 = unsqueeze_3 = None
        view_8: "f32[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(mul_1, [64, 1, 64]);  mul_1 = None
        view_9: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_8, [64, 64]);  view_8 = None
        _to_copy_3: "f8e4m3fn[64, 64][64, 1]cuda:0" = torch.ops.aten._to_copy.default(view_9, dtype = torch.float8_e4m3fn);  view_9 = None
        t: "f32[1, 64][1, 1]cuda:0" = torch.ops.aten.t.default(div_1);  div_1 = None
        new_ones: "f32[1, 1][1, 1]cuda:0" = torch.ops.aten.new_ones.default(div, [1, 1], pin_memory = False)
        new_ones_1: "f32[1, 1][1, 1]cuda:0" = torch.ops.aten.new_ones.default(t, [1, 1], pin_memory = False)
        t_2: "f8e4m3fn[64, 64][1, 64]cuda:0" = torch.ops.aten.t.default(_to_copy_3);  _to_copy_3 = None
        t_3: "f32[1, 1][1, 1]cuda:0" = torch.ops.aten.t.default(new_ones_1);  new_ones_1 = None
        _scaled_mm: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten._scaled_mm.default(_to_copy_1, t_2, new_ones, t_3, None, None, torch.bfloat16);  _to_copy_1 = t_2 = new_ones = t_3 = None
        view_10: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(_scaled_mm, [64, 1, 64]);  _scaled_mm = None
        view_11: "bf16[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_10, [64, 1, 1, 64]);  view_10 = None
        slice_5: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(div, 0, 0, 9223372036854775807);  div = None
        unsqueeze_4: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_5, 1);  slice_5 = None
        slice_6: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_4, 2, 0, 9223372036854775807);  unsqueeze_4 = None
        unsqueeze_5: "f32[64, 1, 1, 1][1, 1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_6, 3);  slice_6 = None
        mul_2: "f32[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_11, unsqueeze_5);  view_11 = unsqueeze_5 = None
        view_12: "f32[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(mul_2, [64, 1, 64]);  mul_2 = None
        view_13: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_12, [64, 64]);  view_12 = None
        view_14: "f32[1, 64, 64][4096, 64, 1]cuda:0" = torch.ops.aten.view.default(view_13, [1, 64, 64]);  view_13 = None
        view_15: "f32[1, 64, 64, 1][4096, 64, 1, 1]cuda:0" = torch.ops.aten.view.default(view_14, [1, 64, 64, 1]);  view_14 = None
        slice_7: "f32[1, 64][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(t, 0, 0, 9223372036854775807);  t = None
        unsqueeze_6: "f32[1, 1, 64][1, 64, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_7, 1);  slice_7 = None
        slice_8: "f32[1, 1, 64][1, 64, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_6, 2, 0, 9223372036854775807);  unsqueeze_6 = None
        unsqueeze_7: "f32[1, 1, 64, 1][1, 64, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_8, 3);  slice_8 = None
        mul_3: "f32[1, 64, 64, 1][4096, 64, 1, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_15, unsqueeze_7);  view_15 = unsqueeze_7 = None
        view_16: "f32[64, 64, 1][64, 1, 1]cuda:0" = torch.ops.aten.view.default(mul_3, [64, 64, 1]);  mul_3 = None
        view_17: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_16, [64, 64]);  view_16 = None
        _to_copy_4: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten._to_copy.default(view_17, dtype = torch.bfloat16);  view_17 = None
        add: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.add.Tensor(_to_copy_4, primals_3);  _to_copy_4 = primals_3 = None
        t_4: "bf16[64, 64][1, 64]cuda:0" = torch.ops.aten.t.default(primals_2);  primals_2 = None
        clone: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.clone.default(t_4, memory_format = torch.contiguous_format);  t_4 = None
        t_5: "bf16[1, 64][1, 1]cuda:0" = torch.ops.aten.t.default(amax_1);  amax_1 = None
        view_21: "bf16[1, 1, 64][1, 64, 1]cuda:0" = torch.ops.aten.view.default(t_5, [1, 1, 64]);  t_5 = None
        amax_3: "bf16[1, 1][1, 1]cuda:0" = torch.ops.aten.amax.default(view_21, [-1]);  view_21 = None
        unsqueeze_8: "bf16[1, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(amax_3, 1);  amax_3 = None
        expand: "bf16[1, 64, 1][1, 0, 1]cuda:0" = torch.ops.aten.expand.default(unsqueeze_8, [1, 64, 1])
        clone_1: "bf16[1, 64, 1][64, 1, 1]cuda:0" = torch.ops.aten.clone.default(expand, memory_format = torch.contiguous_format);  expand = None
        view_22: "bf16[64, 1][1, 1]cuda:0" = torch.ops.aten.view.default(clone_1, [64, 1]);  clone_1 = None
        _to_copy_7: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten._to_copy.default(view_22, dtype = torch.float32);  view_22 = None
        clamp_3: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.clamp.default(_to_copy_7, 1e-12);  _to_copy_7 = None
        div_3: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.div.Tensor(clamp_3, 448.0);  clamp_3 = None
        reciprocal_3: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.reciprocal.default(div_3);  div_3 = None
        view_27: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(clone, [64, 1, 64]);  clone = None
        view_28: "bf16[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_27, [64, 1, 1, 64]);  view_27 = None
        slice_11: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(reciprocal_3, 0, 0, 9223372036854775807);  reciprocal_3 = None
        unsqueeze_11: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_11, 1);  slice_11 = None
        slice_12: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_11, 2, 0, 9223372036854775807);  unsqueeze_11 = None
        unsqueeze_12: "f32[64, 1, 1, 1][1, 1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_12, 3);  slice_12 = None
        mul_5: "f32[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_28, unsqueeze_12);  view_28 = unsqueeze_12 = None
        view_29: "f32[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(mul_5, [64, 1, 64]);  mul_5 = None
        view_30: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_29, [64, 64]);  view_29 = None
        _to_copy_8: "f8e4m3fn[64, 64][64, 1]cuda:0" = torch.ops.aten._to_copy.default(view_30, dtype = torch.float8_e4m3fn);  view_30 = None
        t_8: "f8e4m3fn[64, 64][1, 64]cuda:0" = torch.ops.aten.t.default(_to_copy_8);  _to_copy_8 = None

        # No stacktrace found for following nodes
        view_39: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(add, [64, 64]);  add = None
        return (view_39, primals_1, unsqueeze_8, t_8)

INFO: TRACED GRAPH
 ===== Backward graph 0 =====
 <eval_with_key>.1 class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "bf16[64, 64][64, 1]cuda:0", unsqueeze_8: "bf16[1, 1, 1][1, 1, 1]cuda:0", t_8: "f8e4m3fn[64, 64][1, 64]cuda:0", tangents_1: "bf16[64, 64][64, 1]cuda:0"):
         # File: /data/users/hirsheybar/checkout2/pytorch/test/dynamo/test_repros.py:6946 in forward, code: out = out.unflatten(0, input.shape[:-1])
        view_19: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(tangents_1, [64, 64]);  tangents_1 = None

         # File: /data/users/hirsheybar/checkout2/pytorch/test/dynamo/test_repros.py:6943 in forward, code: out = Fp8LinearFn.apply(
        abs_3: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.abs.default(view_19)
        view_20: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(abs_3, [64, 1, 64]);  abs_3 = None
        amax_2: "bf16[64, 1][1, 1]cuda:0" = torch.ops.aten.amax.default(view_20, [-1]);  view_20 = None
        expand: "bf16[1, 64, 1][1, 0, 1]cuda:0" = torch.ops.aten.expand.default(unsqueeze_8, [1, 64, 1]);  unsqueeze_8 = None
        clone_1: "bf16[1, 64, 1][64, 1, 1]cuda:0" = torch.ops.aten.clone.default(expand, memory_format = torch.contiguous_format);  expand = None
        view_22: "bf16[64, 1][1, 1]cuda:0" = torch.ops.aten.view.default(clone_1, [64, 1]);  clone_1 = None
        _to_copy_5: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten._to_copy.default(amax_2, dtype = torch.float32);  amax_2 = None
        clamp_2: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.clamp.default(_to_copy_5, 1e-12);  _to_copy_5 = None
        div_2: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.div.Tensor(clamp_2, 448.0);  clamp_2 = None
        reciprocal_2: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.reciprocal.default(div_2)
        view_23: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_19, [64, 1, 64])
        view_24: "bf16[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_23, [64, 1, 1, 64]);  view_23 = None
        slice_9: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(reciprocal_2, 0, 0, 9223372036854775807);  reciprocal_2 = None
        unsqueeze_9: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_9, 1);  slice_9 = None
        slice_10: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_9, 2, 0, 9223372036854775807);  unsqueeze_9 = None
        unsqueeze_10: "f32[64, 1, 1, 1][1, 1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_10, 3);  slice_10 = None
        mul_4: "f32[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_24, unsqueeze_10);  view_24 = unsqueeze_10 = None
        view_25: "f32[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(mul_4, [64, 1, 64]);  mul_4 = None
        view_26: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_25, [64, 64]);  view_25 = None
        _to_copy_6: "f8e4m3fn[64, 64][64, 1]cuda:0" = torch.ops.aten._to_copy.default(view_26, dtype = torch.float8_e4m3fn);  view_26 = None
        _to_copy_7: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten._to_copy.default(view_22, dtype = torch.float32);  view_22 = None
        clamp_3: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.clamp.default(_to_copy_7, 1e-12);  _to_copy_7 = None
        div_3: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.div.Tensor(clamp_3, 448.0);  clamp_3 = None
        t_6: "f32[1, 64][1, 1]cuda:0" = torch.ops.aten.t.default(div_3);  div_3 = None
        new_ones_2: "f32[1, 1][1, 1]cuda:0" = torch.ops.aten.new_ones.default(div_2, [1, 1], pin_memory = False)
        new_ones_3: "f32[1, 1][1, 1]cuda:0" = torch.ops.aten.new_ones.default(t_6, [1, 1], pin_memory = False)
        t_9: "f32[1, 1][1, 1]cuda:0" = torch.ops.aten.t.default(new_ones_3);  new_ones_3 = None
        _scaled_mm_1: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten._scaled_mm.default(_to_copy_6, t_8, new_ones_2, t_9, None, None, torch.bfloat16);  _to_copy_6 = t_8 = new_ones_2 = t_9 = None
        view_31: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(_scaled_mm_1, [64, 1, 64]);  _scaled_mm_1 = None
        view_32: "bf16[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_31, [64, 1, 1, 64]);  view_31 = None
        slice_13: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(div_2, 0, 0, 9223372036854775807);  div_2 = None
        unsqueeze_13: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_13, 1);  slice_13 = None
        slice_14: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_13, 2, 0, 9223372036854775807);  unsqueeze_13 = None
        unsqueeze_14: "f32[64, 1, 1, 1][1, 1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_14, 3);  slice_14 = None
        mul_6: "f32[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_32, unsqueeze_14);  view_32 = unsqueeze_14 = None
        view_33: "f32[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(mul_6, [64, 1, 64]);  mul_6 = None
        view_34: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_33, [64, 64]);  view_33 = None
        view_35: "f32[1, 64, 64][4096, 64, 1]cuda:0" = torch.ops.aten.view.default(view_34, [1, 64, 64]);  view_34 = None
        view_36: "f32[1, 64, 64, 1][4096, 64, 1, 1]cuda:0" = torch.ops.aten.view.default(view_35, [1, 64, 64, 1]);  view_35 = None
        slice_15: "f32[1, 64][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(t_6, 0, 0, 9223372036854775807);  t_6 = None
        unsqueeze_15: "f32[1, 1, 64][1, 64, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_15, 1);  slice_15 = None
        slice_16: "f32[1, 1, 64][1, 64, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_15, 2, 0, 9223372036854775807);  unsqueeze_15 = None
        unsqueeze_16: "f32[1, 1, 64, 1][1, 64, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_16, 3);  slice_16 = None
        mul_7: "f32[1, 64, 64, 1][4096, 64, 1, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_36, unsqueeze_16);  view_36 = unsqueeze_16 = None
        view_37: "f32[64, 64, 1][64, 1, 1]cuda:0" = torch.ops.aten.view.default(mul_7, [64, 64, 1]);  mul_7 = None
        view_38: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_37, [64, 64]);  view_37 = None
        _to_copy_9: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten._to_copy.default(view_38, dtype = torch.bfloat16);  view_38 = None
        t_10: "bf16[64, 64][1, 64]cuda:0" = torch.ops.aten.t.default(view_19)
        mm: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.mm.default(t_10, primals_1);  t_10 = primals_1 = None
        sum_1: "bf16[64][1]cuda:0" = torch.ops.aten.sum.dim_IntList(view_19, [0]);  view_19 = None
        return (_to_copy_9, mm, sum_1)

```

With the change, we save primals_2 for backward instead

```
 ===== Forward graph 0 =====
 /data/users/hirsheybar/checkout2/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "bf16[64, 64][64, 1]cuda:0", primals_2: "bf16[64, 64][64, 1]cuda:0", primals_3: "bf16[64][1]cuda:0"):
         # File: /data/users/hirsheybar/checkout2/pytorch/test/dynamo/test_repros.py:6943 in forward, code: out = Fp8LinearFn.apply(
        abs_1: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.abs.default(primals_1)
        view: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(abs_1, [64, 1, 64]);  abs_1 = None
        amax: "bf16[64, 1][1, 1]cuda:0" = torch.ops.aten.amax.default(view, [-1]);  view = None
        abs_2: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.abs.default(primals_2)
        view_1: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(abs_2, [64, 1, 64]);  abs_2 = None
        amax_1: "bf16[64, 1][1, 1]cuda:0" = torch.ops.aten.amax.default(view_1, [-1]);  view_1 = None
        _to_copy: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten._to_copy.default(amax, dtype = torch.float32);  amax = None
        clamp: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.clamp.default(_to_copy, 1e-12);  _to_copy = None
        div: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.div.Tensor(clamp, 448.0);  clamp = None
        reciprocal: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.reciprocal.default(div)
        view_2: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(primals_1, [64, 1, 64])
        view_3: "bf16[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_2, [64, 1, 1, 64]);  view_2 = None
        slice_1: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(reciprocal, 0, 0, 9223372036854775807);  reciprocal = None
        unsqueeze: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_1, 1);  slice_1 = None
        slice_2: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze, 2, 0, 9223372036854775807);  unsqueeze = None
        unsqueeze_1: "f32[64, 1, 1, 1][1, 1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_2, 3);  slice_2 = None
        mul: "f32[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_3, unsqueeze_1);  view_3 = unsqueeze_1 = None
        view_4: "f32[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(mul, [64, 1, 64]);  mul = None
        view_5: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_4, [64, 64]);  view_4 = None
        _to_copy_1: "f8e4m3fn[64, 64][64, 1]cuda:0" = torch.ops.aten._to_copy.default(view_5, dtype = torch.float8_e4m3fn);  view_5 = None
        _to_copy_2: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten._to_copy.default(amax_1, dtype = torch.float32)
        clamp_1: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.clamp.default(_to_copy_2, 1e-12);  _to_copy_2 = None
        div_1: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.div.Tensor(clamp_1, 448.0);  clamp_1 = None
        reciprocal_1: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.reciprocal.default(div_1)
        view_6: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(primals_2, [64, 1, 64])
        view_7: "bf16[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_6, [64, 1, 1, 64]);  view_6 = None
        slice_3: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(reciprocal_1, 0, 0, 9223372036854775807);  reciprocal_1 = None
        unsqueeze_2: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_3, 1);  slice_3 = None
        slice_4: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_2, 2, 0, 9223372036854775807);  unsqueeze_2 = None
        unsqueeze_3: "f32[64, 1, 1, 1][1, 1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_4, 3);  slice_4 = None
        mul_1: "f32[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_7, unsqueeze_3);  view_7 = unsqueeze_3 = None
        view_8: "f32[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(mul_1, [64, 1, 64]);  mul_1 = None
        view_9: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_8, [64, 64]);  view_8 = None
        _to_copy_3: "f8e4m3fn[64, 64][64, 1]cuda:0" = torch.ops.aten._to_copy.default(view_9, dtype = torch.float8_e4m3fn);  view_9 = None
        t: "f32[1, 64][1, 1]cuda:0" = torch.ops.aten.t.default(div_1);  div_1 = None
        new_ones: "f32[1, 1][1, 1]cuda:0" = torch.ops.aten.new_ones.default(div, [1, 1], pin_memory = False)
        new_ones_1: "f32[1, 1][1, 1]cuda:0" = torch.ops.aten.new_ones.default(t, [1, 1], pin_memory = False)
        t_2: "f8e4m3fn[64, 64][1, 64]cuda:0" = torch.ops.aten.t.default(_to_copy_3);  _to_copy_3 = None
        t_3: "f32[1, 1][1, 1]cuda:0" = torch.ops.aten.t.default(new_ones_1);  new_ones_1 = None
        _scaled_mm: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten._scaled_mm.default(_to_copy_1, t_2, new_ones, t_3, None, None, torch.bfloat16);  _to_copy_1 = t_2 = new_ones = t_3 = None
        view_10: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(_scaled_mm, [64, 1, 64]);  _scaled_mm = None
        view_11: "bf16[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_10, [64, 1, 1, 64]);  view_10 = None
        slice_5: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(div, 0, 0, 9223372036854775807);  div = None
        unsqueeze_4: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_5, 1);  slice_5 = None
        slice_6: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_4, 2, 0, 9223372036854775807);  unsqueeze_4 = None
        unsqueeze_5: "f32[64, 1, 1, 1][1, 1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_6, 3);  slice_6 = None
        mul_2: "f32[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_11, unsqueeze_5);  view_11 = unsqueeze_5 = None
        view_12: "f32[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(mul_2, [64, 1, 64]);  mul_2 = None
        view_13: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_12, [64, 64]);  view_12 = None
        view_14: "f32[1, 64, 64][4096, 64, 1]cuda:0" = torch.ops.aten.view.default(view_13, [1, 64, 64]);  view_13 = None
        view_15: "f32[1, 64, 64, 1][4096, 64, 1, 1]cuda:0" = torch.ops.aten.view.default(view_14, [1, 64, 64, 1]);  view_14 = None
        slice_7: "f32[1, 64][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(t, 0, 0, 9223372036854775807);  t = None
        unsqueeze_6: "f32[1, 1, 64][1, 64, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_7, 1);  slice_7 = None
        slice_8: "f32[1, 1, 64][1, 64, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_6, 2, 0, 9223372036854775807);  unsqueeze_6 = None
        unsqueeze_7: "f32[1, 1, 64, 1][1, 64, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_8, 3);  slice_8 = None
        mul_3: "f32[1, 64, 64, 1][4096, 64, 1, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_15, unsqueeze_7);  view_15 = unsqueeze_7 = None
        view_16: "f32[64, 64, 1][64, 1, 1]cuda:0" = torch.ops.aten.view.default(mul_3, [64, 64, 1]);  mul_3 = None
        view_17: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_16, [64, 64]);  view_16 = None
        _to_copy_4: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten._to_copy.default(view_17, dtype = torch.bfloat16);  view_17 = None
        add: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.add.Tensor(_to_copy_4, primals_3);  _to_copy_4 = primals_3 = None
        t_5: "bf16[1, 64][1, 1]cuda:0" = torch.ops.aten.t.default(amax_1);  amax_1 = None
        view_21: "bf16[1, 1, 64][1, 64, 1]cuda:0" = torch.ops.aten.view.default(t_5, [1, 1, 64]);  t_5 = None
        amax_3: "bf16[1, 1][1, 1]cuda:0" = torch.ops.aten.amax.default(view_21, [-1]);  view_21 = None
        unsqueeze_8: "bf16[1, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(amax_3, 1);  amax_3 = None

        # No stacktrace found for following nodes
        view_39: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(add, [64, 64]);  add = None
        return (view_39, primals_1, primals_2, unsqueeze_8)

INFO: TRACED GRAPH
 ===== Backward graph 0 =====
 <eval_with_key>.1 class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "bf16[64, 64][64, 1]cuda:0", primals_2: "bf16[64, 64][64, 1]cuda:0", unsqueeze_8: "bf16[1, 1, 1][1, 1, 1]cuda:0", tangents_1: "bf16[64, 64][64, 1]cuda:0"):
         # File: /data/users/hirsheybar/checkout2/pytorch/test/dynamo/test_repros.py:6946 in forward, code: out = out.unflatten(0, input.shape[:-1])
        view_19: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(tangents_1, [64, 64]);  tangents_1 = None

         # File: /data/users/hirsheybar/checkout2/pytorch/test/dynamo/test_repros.py:6943 in forward, code: out = Fp8LinearFn.apply(
        t_4: "bf16[64, 64][1, 64]cuda:0" = torch.ops.aten.t.default(primals_2);  primals_2 = None
        clone: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.clone.default(t_4, memory_format = torch.contiguous_format);  t_4 = None
        abs_3: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.abs.default(view_19)
        view_20: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(abs_3, [64, 1, 64]);  abs_3 = None
        amax_2: "bf16[64, 1][1, 1]cuda:0" = torch.ops.aten.amax.default(view_20, [-1]);  view_20 = None
        expand: "bf16[1, 64, 1][1, 0, 1]cuda:0" = torch.ops.aten.expand.default(unsqueeze_8, [1, 64, 1]);  unsqueeze_8 = None
        clone_1: "bf16[1, 64, 1][64, 1, 1]cuda:0" = torch.ops.aten.clone.default(expand, memory_format = torch.contiguous_format);  expand = None
        view_22: "bf16[64, 1][1, 1]cuda:0" = torch.ops.aten.view.default(clone_1, [64, 1]);  clone_1 = None
        _to_copy_5: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten._to_copy.default(amax_2, dtype = torch.float32);  amax_2 = None
        clamp_2: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.clamp.default(_to_copy_5, 1e-12);  _to_copy_5 = None
        div_2: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.div.Tensor(clamp_2, 448.0);  clamp_2 = None
        reciprocal_2: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.reciprocal.default(div_2)
        view_23: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_19, [64, 1, 64])
        view_24: "bf16[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_23, [64, 1, 1, 64]);  view_23 = None
        slice_9: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(reciprocal_2, 0, 0, 9223372036854775807);  reciprocal_2 = None
        unsqueeze_9: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_9, 1);  slice_9 = None
        slice_10: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_9, 2, 0, 9223372036854775807);  unsqueeze_9 = None
        unsqueeze_10: "f32[64, 1, 1, 1][1, 1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_10, 3);  slice_10 = None
        mul_4: "f32[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_24, unsqueeze_10);  view_24 = unsqueeze_10 = None
        view_25: "f32[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(mul_4, [64, 1, 64]);  mul_4 = None
        view_26: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_25, [64, 64]);  view_25 = None
        _to_copy_6: "f8e4m3fn[64, 64][64, 1]cuda:0" = torch.ops.aten._to_copy.default(view_26, dtype = torch.float8_e4m3fn);  view_26 = None
        _to_copy_7: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten._to_copy.default(view_22, dtype = torch.float32);  view_22 = None
        clamp_3: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.clamp.default(_to_copy_7, 1e-12);  _to_copy_7 = None
        div_3: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.div.Tensor(clamp_3, 448.0);  clamp_3 = None
        reciprocal_3: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.reciprocal.default(div_3)
        view_27: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(clone, [64, 1, 64]);  clone = None
        view_28: "bf16[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_27, [64, 1, 1, 64]);  view_27 = None
        slice_11: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(reciprocal_3, 0, 0, 9223372036854775807);  reciprocal_3 = None
        unsqueeze_11: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_11, 1);  slice_11 = None
        slice_12: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_11, 2, 0, 9223372036854775807);  unsqueeze_11 = None
        unsqueeze_12: "f32[64, 1, 1, 1][1, 1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_12, 3);  slice_12 = None
        mul_5: "f32[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_28, unsqueeze_12);  view_28 = unsqueeze_12 = None
        view_29: "f32[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(mul_5, [64, 1, 64]);  mul_5 = None
        view_30: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_29, [64, 64]);  view_29 = None
        _to_copy_8: "f8e4m3fn[64, 64][64, 1]cuda:0" = torch.ops.aten._to_copy.default(view_30, dtype = torch.float8_e4m3fn);  view_30 = None
        t_6: "f32[1, 64][1, 1]cuda:0" = torch.ops.aten.t.default(div_3);  div_3 = None
        new_ones_2: "f32[1, 1][1, 1]cuda:0" = torch.ops.aten.new_ones.default(div_2, [1, 1], pin_memory = False)
        new_ones_3: "f32[1, 1][1, 1]cuda:0" = torch.ops.aten.new_ones.default(t_6, [1, 1], pin_memory = False)
        t_8: "f8e4m3fn[64, 64][1, 64]cuda:0" = torch.ops.aten.t.default(_to_copy_8);  _to_copy_8 = None
        t_9: "f32[1, 1][1, 1]cuda:0" = torch.ops.aten.t.default(new_ones_3);  new_ones_3 = None
        _scaled_mm_1: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten._scaled_mm.default(_to_copy_6, t_8, new_ones_2, t_9, None, None, torch.bfloat16);  _to_copy_6 = t_8 = new_ones_2 = t_9 = None
        view_31: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(_scaled_mm_1, [64, 1, 64]);  _scaled_mm_1 = None
        view_32: "bf16[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_31, [64, 1, 1, 64]);  view_31 = None
        slice_13: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(div_2, 0, 0, 9223372036854775807);  div_2 = None
        unsqueeze_13: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_13, 1);  slice_13 = None
        slice_14: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_13, 2, 0, 9223372036854775807);  unsqueeze_13 = None
        unsqueeze_14: "f32[64, 1, 1, 1][1, 1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_14, 3);  slice_14 = None
        mul_6: "f32[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_32, unsqueeze_14);  view_32 = unsqueeze_14 = None
        view_33: "f32[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(mul_6, [64, 1, 64]);  mul_6 = None
        view_34: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_33, [64, 64]);  view_33 = None
        view_35: "f32[1, 64, 64][4096, 64, 1]cuda:0" = torch.ops.aten.view.default(view_34, [1, 64, 64]);  view_34 = None
        view_36: "f32[1, 64, 64, 1][4096, 64, 1, 1]cuda:0" = torch.ops.aten.view.default(view_35, [1, 64, 64, 1]);  view_35 = None
        slice_15: "f32[1, 64][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(t_6, 0, 0, 9223372036854775807);  t_6 = None
        unsqueeze_15: "f32[1, 1, 64][1, 64, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_15, 1);  slice_15 = None
        slice_16: "f32[1, 1, 64][1, 64, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_15, 2, 0, 9223372036854775807);  unsqueeze_15 = None
        unsqueeze_16: "f32[1, 1, 64, 1][1, 64, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_16, 3);  slice_16 = None
        mul_7: "f32[1, 64, 64, 1][4096, 64, 1, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_36, unsqueeze_16);  view_36 = unsqueeze_16 = None
        view_37: "f32[64, 64, 1][64, 1, 1]cuda:0" = torch.ops.aten.view.default(mul_7, [64, 64, 1]);  mul_7 = None
        view_38: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_37, [64, 64]);  view_37 = None
        _to_copy_9: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten._to_copy.default(view_38, dtype = torch.bfloat16);  view_38 = None
        t_10: "bf16[64, 64][1, 64]cuda:0" = torch.ops.aten.t.default(view_19)
        mm: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.mm.default(t_10, primals_1);  t_10 = primals_1 = None
        sum_1: "bf16[64][1]cuda:0" = torch.ops.aten.sum.dim_IntList(view_19, [0]);  view_19 = None
        return (_to_copy_9, mm, sum_1)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148922
Approved by: https://github.com/zou3519
2025-03-18 20:08:11 +00:00
Brian Hirsh
3646d4dbc8 [partitioner] always ban compiler-driven recompute of collectives by default (#147561)
This should fix the hang in https://fb.workplace.com/groups/1075192433118967/permalink/1603268720311333/

The argument here is that:

(1) in general, it is not safe for the partitioner to sometimes choose to recompute collectives in the backward. Why? If we are running a distributed job, where many ranks are compiling at the same time, we need every rank to make a consistent decision about which collectives are recomputed for backward. If we let each compiler instance make its own choice without any cross-rank communication, they can make different choices and cause NCCL hangs (see the link above)

(2) later on, we'll want an `spmd_mode` flag that causes the compiler to issue collectives and communicate info across ranks. Once we have such a config, then turning it on should make it safe for the partitioner to potentially choose to recompute collectives (and agree on the binary "recompute-or-save" choice across all ranks)

(3) even without an `spmd_mode`, users can override this choice by using `torch.utils.checkpoint()` in their user code. User checkpointing generally always overrides the partitioner, and this should be safe because we expect the user to apply checkpointing consistently across ranks

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147561
Approved by: https://github.com/zou3519
2025-03-13 03:36:13 +00:00
Luca Wehrstedt
f80aad62fa Improve Pareto frontier plot for AutoAC (#148678)
This was added in https://github.com/pytorch/pytorch/pull/126320. It's a very nice feature, which can be used to predict memory usage for different budget values.

However, it had some limitations, notably in terms of resolution (it only sampled 21 points across the whole range thus missed many threshold values) and in distributed settings.

Here I fix those by using recursive binary searches to identify all thresholds (up to a resolution of 1e-3, which can be made configurable) and output them in SVG (to be able to discern different points), plus I add the rank to the filename and store it in a user-define directory.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148678
Approved by: https://github.com/Chillee, https://github.com/fmassa
2025-03-07 13:22:29 +00:00
eellison
481a57bc37 Support torch.compile rng selective activation checkpointing with cudagraph (#146878)
TODO:
- [x]  Add handling for when forward is invoked multiple times without invoking backward, so that the fwd/backward states are out of sync
- [x] Update rng state initialization to take from correct device
- [x]  Tests
- [x] handling of retain_graph
- [x] respect fallback random

Fix for https://github.com/pytorch/pytorch/issues/130123.

Updates the aot_eager and cudagraph compilation of `run_and_save_rng_state` to use the new mechanism added by https://github.com/pytorch/pytorch/pull/114068 for CUDAGraph safe rng states.

We have a pair of rng states for the fwd and backward respectively. In both forward and backward the rng op will get run with `graphsafe_run_with_rng_state` which takes in RNG state and it hooks onto the current RNG generator before running the operator. The rng states for fwd/backward are initialized with the same value. We ensure that for any given run of the forward, the corresponding backward run will have the same rng states for the op as was observed in the forward.

```
 ===== Forward graph 1 =====
 /data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", fwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = fwd_rng_state_0);  fwd_rng_state_0 = None
        ...

 ===== Backward graph 1 =====
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", tangents_1: "f32[4, 4][4, 1]cuda:0", bwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = bwd_rng_state_0);  bwd_rng_state_0 = None
```

There is some extra complication when a user either calls backward with retain_graph, or calls the backward in a different order as they called the forward. If a user has state fwd_rng_state0, bwd_rng_state0 and calls:
- fwd0: fwd_rng_state0 -> fwd_rng_state1
- fwd1: fwd_rng_state1 -> fwd_rng_state2
- bwd1
- bwd0

Then naively, when bwd1 is invoked the bwd rng states would not be equal to the same states that were observed in fwd1. I added handling of this in the aot runtime wrappers to detect pending backward invocations, and the current position of the bwd rng states, and to update when necesssary.

Other notes:

Because nodes which appear later in the forward appear earlier in the backward, we need a separate rng state for each operator. If we reused the rng across ops, the forward and backward would be run with different rng states. I.e., not applied in the same order.

Questions for reviewers:

This does change numerics, bc the rng of the op is now taken from the input rng state instead of whatever the rng would be midway through running the graph. Technically, we only need this for cuda graph. But, I'd prefer to not have a rng divergence just for cudagraph. I am making it respect `fallback_random`.

Edit: decided to apply to non cudagraphs as well, so long as fallback_random is not set

I'm initializing the rng states by cloning the current state. If you had something like 5 different rands in the model with the same shape, theyd all get the same value. This doesn't seem great. I could use some other initialization scheme like taking seed from graph position, or etc etc. Not sure. Let me know thoughts.

Edit: updated to be taken from randint()

Update: initializing rng states from torch.randint..

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146878
Approved by: https://github.com/anijain2305, https://github.com/bdhirsh
2025-02-28 00:47:03 +00:00
PyTorch MergeBot
17358ce778 Revert "Support torch.compile rng selective activation checkpointing with cudagraph (#146878)"
This reverts commit ad0c879e22.

Reverted https://github.com/pytorch/pytorch/pull/146878 on behalf of https://github.com/wdvr due to lint failure ([comment](https://github.com/pytorch/pytorch/pull/146878#issuecomment-2686767956))
2025-02-27 03:36:16 +00:00
eellison
ad0c879e22 Support torch.compile rng selective activation checkpointing with cudagraph (#146878)
TODO:
- [x]  Add handling for when forward is invoked multiple times without invoking backward, so that the fwd/backward states are out of sync
- [x] Update rng state initialization to take from correct device
- [x]  Tests
- [x] handling of retain_graph
- [x] respect fallback random

Fix for https://github.com/pytorch/pytorch/issues/130123.

Updates the aot_eager and cudagraph compilation of `run_and_save_rng_state` to use the new mechanism added by https://github.com/pytorch/pytorch/pull/114068 for CUDAGraph safe rng states.

We have a pair of rng states for the fwd and backward respectively. In both forward and backward the rng op will get run with `graphsafe_run_with_rng_state` which takes in RNG state and it hooks onto the current RNG generator before running the operator. The rng states for fwd/backward are initialized with the same value. We ensure that for any given run of the forward, the corresponding backward run will have the same rng states for the op as was observed in the forward.

```
 ===== Forward graph 1 =====
 /data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", fwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = fwd_rng_state_0);  fwd_rng_state_0 = None
        ...

 ===== Backward graph 1 =====
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", tangents_1: "f32[4, 4][4, 1]cuda:0", bwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = bwd_rng_state_0);  bwd_rng_state_0 = None
```

There is some extra complication when a user either calls backward with retain_graph, or calls the backward in a different order as they called the forward. If a user has state fwd_rng_state0, bwd_rng_state0 and calls:
- fwd0: fwd_rng_state0 -> fwd_rng_state1
- fwd1: fwd_rng_state1 -> fwd_rng_state2
- bwd1
- bwd0

Then naively, when bwd1 is invoked the bwd rng states would not be equal to the same states that were observed in fwd1. I added handling of this in the aot runtime wrappers to detect pending backward invocations, and the current position of the bwd rng states, and to update when necesssary.

Other notes:

Because nodes which appear later in the forward appear earlier in the backward, we need a separate rng state for each operator. If we reused the rng across ops, the forward and backward would be run with different rng states. I.e., not applied in the same order.

Questions for reviewers:

This does change numerics, bc the rng of the op is now taken from the input rng state instead of whatever the rng would be midway through running the graph. Technically, we only need this for cuda graph. But, I'd prefer to not have a rng divergence just for cudagraph. I am making it respect `fallback_random`.

Edit: decided to apply to non cudagraphs as well, so long as fallback_random is not set

I'm initializing the rng states by cloning the current state. If you had something like 5 different rands in the model with the same shape, theyd all get the same value. This doesn't seem great. I could use some other initialization scheme like taking seed from graph position, or etc etc. Not sure. Let me know thoughts.

Edit: updated to be taken from randint()

Update: initializing rng states from torch.randint..

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146878
Approved by: https://github.com/anijain2305, https://github.com/bdhirsh
2025-02-27 02:08:29 +00:00
Oguz Ulgen
076215944a Turn on autograd local caches in fbcode (#146996)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146996
Approved by: https://github.com/jamesjwu
2025-02-12 23:04:39 +00:00
Yidi Wu
c7dbee5106 [reland][export] don't decompose custom triton op when exporting (#144284)
Summary:
A reland of https://github.com/pytorch/pytorch/pull/142426.

Copying the description over here:

For torch.export (strict and non-strict), we don't do functional decomposition. Instead, we preserve the custom triton ops as custom ops. This is because we want the exported program to be high-level and serializable.

The alternative:
If we decompose the custom op to a functional hop and make it a node in exported program, we need to figure out ways of serializing the hop and its arguments, which can be triton.jited python functions and triton dtypes. This is undesireble because:

it can be tedious to maintain layer that serialize the jited function (e.g. with a string) and dtypes.
changes to triton or the serialization logic for triton arguments can be BC breaking
exported program will expose the implementation detail (i.e. triton source code) for a specific backend (GPU) to users, which mixes levels of abstraction.

Future plans:
After this PR, in the short term, we expect users to have a seperate aot_compile stage that compiles the exported program into a Cubin file on the same machine that users call export, which does autotuning and removes triton dependency and serve the model with Cubin. This guarantees that triton changes won't break BC.

In the long term, we may export multiple cubins for the triton op directly.

Test Plan: see new tests.

Differential Revision: D67879685

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144284
Approved by: https://github.com/zou3519
2025-01-11 01:34:35 +00:00
James Wu
fbbafd0320 Turn on AOTAutogradCache by default on open source (#141981)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141981
Approved by: https://github.com/bdhirsh, https://github.com/oulgen
2024-12-12 04:21:11 +00:00
IvanKobzarev
661d1f0372 [aotd] non-contiguous NestedTensor mutation in compile (#139630)
Allow mutations mutations for subclasses that are non-contiguous.

Changes:

Removing assert in collect_metadata_analysis

Main requested testcase:
Compilation of NJT.index_put()

Adding test in test_nestedtensor.py, that compiles NJT.index_put()

It is  decomposed to NJT split,unbind, which  needed additional `torch._check`, `torch._check_is_size` for NJT.unbind()  and guard_size_oblivious() usage in _meta_registrations and _inductor/lowering.py.

Special case:
If tangent is mutated outside of the graph, it does not participate in backward graph. Autograd in this case will set this tangent to zeros tensor.

We handle it separately in CompiledFunction.backward: not doing any processing for this tangent and broadcast to number of expected subclass unwrapped arguments.

disabling for dynamo 2 tests:
1/ For nested tensor - symbolic shapes issue on nested_tensor index operation that does splits [0, 0, 0] - there is a failure with "pending unbacked symints". This PR does not add more .tolist()/item() ops than it was before.

2/ As we do not fail with exception in collect_metadata_analysis new paths for dynamo started working and it started failing with smth strange that set_ in storage_offset (because of test for views) handling updates storage "cpu" -> "meta"

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139630
Approved by: https://github.com/bdhirsh
2024-12-06 12:18:46 +00:00
Yukio Siraichi
470b775d7a Remove functorch config: _max_aliased_inputs_with_dynamic_shapes_enabled. (#141680)
This PR removes the functorch config that set an upper limit on the number of aliased
inputs with dynamic shapes. After moving them to be run at runtime in C++, the compilation
time and runtime (in true alias cases) improved, rendering the error no longer relevant.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141680
Approved by: https://github.com/bdhirsh
ghstack dependencies: #139554, #139555, #140013
2024-12-05 14:43:58 +00:00
Pian Pawakapan
1132b6764a [draft export] generate fake outputs when real tensor prop finds mismatches (#139766)
Currently real tensor tracing raises MetadataMismatchErrors if registered fake kernels don't match the real kernels (e.g. shape, aliasing, dtype, etc.). This adds an option to use fake kernel inference to bypass mismatches - this option defaults to False for real tensor tracing, but is on for draft export.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139766
Approved by: https://github.com/angelayi, https://github.com/zou3519
2024-11-21 08:01:09 +00:00
Boyuan Feng
87059d4547 [AOTAutograd] Handle edge cases for donated buffer & enable in oss (#139669)
This PR enables donated buffer in OSS and handles two edge cases:

1. While donated buffer relies on storage to check alias, sparse tensor subclasses does not provide access to storage. So we skip sparse tensor subclasses for donated buffer.
2. Handles missing "val" from n.meta. This is observed from `inductor/test_fused_attention.py::SDPAPatternRewriterCpuTests::test_sdpa_rewriter_11_cpu`,
`functorch/test_aotdispatch.py::TestAOTAutograd::test_input_mutation_simple_with_none_and_nontensor`, and
`inductor/test_compiled_autograd.py::TestCompiledAutograd::test_trace_run_with_rng_state`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139669
Approved by: https://github.com/bdhirsh
2024-11-05 18:38:20 +00:00
eellison
ee2f8a50d3 Class rename (#139490)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139490
Approved by: https://github.com/exclamaforte, https://github.com/zou3519
ghstack dependencies: #139295
2024-11-02 00:10:17 +00:00
eellison
fe18a221eb Add debug backend that applies CrossRefFakeMode, use in compiler bisector (#138651)
I was debugging an internal ne divergence for a while that ended up being because of a bad meta. I added an explicit a config option and an explicit backend `aot_eager_decomp_partition_crossref` to enable the FakeCrossRefMode when running the graph.  I added an explicit backend bc I suspect it will be useful for internal models but I'm also happy to leave as config option.

It will only test ops that have meta to avoid memory overhead of hitting fallback path and running in eager.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138651
Approved by: https://github.com/zou3519, https://github.com/bdhirsh
2024-10-25 15:58:36 +00:00
PyTorch MergeBot
796c3c3415 Revert "Disallow FakeTensor.data_ptr access in eager mode (#137221)"
This reverts commit 7e13e7dd7e.

Reverted https://github.com/pytorch/pytorch/pull/137221 on behalf of https://github.com/jovianjaison due to failing internal tests ([comment](https://github.com/pytorch/pytorch/pull/137221#issuecomment-2397957081))
2024-10-07 21:46:13 +00:00
rzou
7e13e7dd7e Disallow FakeTensor.data_ptr access in eager mode (#137221)
Previously we raised a deprecation warning (beginning PyTorch 2.4). Now
that we are on 2.6, we're completing the deprecation and disallowing
this behavior.

Test Plan:
- tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137221
Approved by: https://github.com/albanD, https://github.com/eellison
2024-10-03 23:47:55 +00:00
James Wu
4d3c0fc061 [AOTAutogradCache] add config for AOTAutograd remote cache (#137011)
Summary: This just adds a config option and JK for turning on remote AOTAutogradCache. It does not implement anything with the new options being passed in. That will come next diff.

This PR also changes the command for turning on the local AOTAutogradCache to be more consistent to that of FXGraphCache: TORCHINDUCTOR_AUTOGRAD_CACHE

Test Plan: Existing tests should pass and should build

Reviewed By: oulgen

Differential Revision: D63321965

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137011
Approved by: https://github.com/oulgen
2024-10-03 16:03:47 +00:00
Xuehai Pan
e7eeee473c [BE][Easy][14/19] enforce style for empty lines in import segments in torch/_[a-c]*/ and torch/_[e-h]*/ and torch/_[j-z]*/ (#129765)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129765
Approved by: https://github.com/ezyang
2024-07-31 10:42:50 +00:00
Boyuan Feng
40cc5c0697 [AOT Autograd] Donated Buffer (#130580)
Implements donated buffer feature and adds unit tests. Donated buffer is a saved tensor that is not aliased with forward inputs, fw_outputs (except saved tensors), and bw_outputs. We detect donated buffers during `aot_dispatch_autograd` and store donated buffers in `ViewAndMutationMetadata`, such that it can be accssed in inductor.

Fixes #129496

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130580
Approved by: https://github.com/bdhirsh
2024-07-26 17:14:34 +00:00
Aaron Orenstein
567482973d typing fake_tensor.py (#128041)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128041
Approved by: https://github.com/eellison
ghstack dependencies: #129182
2024-07-13 06:07:40 +00:00