Commit Graph

831 Commits

Author SHA1 Message Date
Prajesh Praveen Anchalia
48e9ffc873 Unify on dynamo_compile as the overall wait counter (#150293)
Summary:
dynamo_compile has, for the most part, been accounting for compile time, except for autotuning.

all_compilation_types had earlier been injected on fx_codegen_and_compile, which was incorrect.

Add autotuning to dynamo and deprecate the all_compilation_types counter.

Differential Revision: D72145447

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150293
Approved by: https://github.com/masnesral, https://github.com/jamesjwu
2025-04-01 08:55:51 +00:00
Prajesh Praveen Anchalia
005c9b2f4f Fix _WaitCounter decorator and add backward pass wait counter (#150235)
Summary:
This will log a wait counter for backward compile and fixes some weirdness with nested context managers.

Since the old wait counters added through dynamo_timed were never created correctly due to the nesting issue, I am also changing the key nomenclature from `pytorch.dynamo_timed` to `pytorch.wait_counter`. We want to use the same nomenclature to make it easy to find keys.
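
Purely as an illustration of the nesting fix described above (a generic sketch, not the actual `_WaitCounter` implementation or its counter backend), a nesting-safe wait-counter decorator can track re-entrancy so that only the outermost call records time under a `pytorch.wait_counter.*` key:

```
import functools
import threading
import time

_local = threading.local()

def wait_counter(key):
    """Illustrative nesting-safe wait-counter decorator (not the real _WaitCounter)."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            depth = getattr(_local, "depth", 0)
            _local.depth = depth + 1
            start = time.monotonic() if depth == 0 else None
            try:
                return fn(*args, **kwargs)
            finally:
                _local.depth = depth
                if depth == 0:
                    # Stand-in for the real counter backend: only the outermost
                    # invocation reports, so nested calls are not double-counted.
                    print(f"{key}: {time.monotonic() - start:.6f}s")
        return wrapper
    return decorator

@wait_counter("pytorch.wait_counter.backward_compile")
def compile_backward():
    compile_backward_inner()

@wait_counter("pytorch.wait_counter.backward_compile")
def compile_backward_inner():
    time.sleep(0.01)

compile_backward()  # reported once, by the outermost wrapper
```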

Reviewed By: jamesjwu

Differential Revision: D72032055

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150235
Approved by: https://github.com/jamesjwu, https://github.com/masnesral
2025-03-30 05:20:12 +00:00
IvanKobzarev
25309a17f0 [aotd] Config to guess_tangents_stride (#150035)
Differential Revision: [D71907684](https://our.internmc.facebook.com/intern/diff/D71907684)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150035
Approved by: https://github.com/ilyas409, https://github.com/seemethere
2025-03-28 13:54:19 +00:00
Tugsbayasgalan Manlaibaatar
c49315e645 Improve attr mismatch msg (#149576)
Differential Revision: [D71513041](https://our.internmc.facebook.com/intern/diff/D71513041)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149576
Approved by: https://github.com/avikchaudhuri
2025-03-28 05:10:56 +00:00
James Wu
49f86a939c [AOTAutogradCache] Allow Custom Autograd functions behind a flag (#149751)
This adds a new env var and flag,

autograd_cache_allow_custom_autograd_functions, (env var: `TORCHINDUCTOR_AUTOGRAD_CACHE_ALLOW_CUSTOM_AUTOGRAD`) which allows custom autograd functions into AOTAutogradCache.

@hirsheybar and I worked together to verify that the higher order op AutogradFunctionApply is pure with respect to the dynamo input being passed in, so this *should* be safe. I'm still putting it behind a flag and turning it on slowly, first on an internal model, though. Once we verify that it is correct on the internal model we can work to enable the flag by default.
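
As a hedged usage sketch (the environment variable name is taken from this commit; setting it before importing torch, and the exact point at which the config is read, are assumptions), a custom autograd.Function compiled with torch.compile becomes eligible for AOTAutogradCache once the flag is on:

```
import os
os.environ["TORCHINDUCTOR_AUTOGRAD_CACHE_ALLOW_CUSTOM_AUTOGRAD"] = "1"  # env var named in this commit

import torch

class MySin(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.sin(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * torch.cos(x)

@torch.compile
def f(x):
    return MySin.apply(x).sum()

x = torch.randn(8, requires_grad=True)
f(x).backward()  # with the flag on, this graph is eligible for AOTAutogradCache
```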

Differential Revision: [D71633184](https://our.internmc.facebook.com/intern/diff/D71633184/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149751
Approved by: https://github.com/bdhirsh, https://github.com/zou3519
2025-03-24 21:12:11 +00:00
James Wu
fe954cdcbf Use correct boxed_forward_device_index when running CompiledFxGraph.post_compile (#148130)
This PR threads through the correct boxed_forward_device_index from graph_kwargs to CompiledFXGraph.post_compile. This allows us to correctly update BoxedDeviceIndex from cache hits.

We don't actually need to save `boxed_forward_device_index` in CompiledFXGraph because its value is in the cache key, so it always matches the ambient one anyway. On forward with cudagraphs enabled, we derive `boxed_forward_device_index`'s value from `device_idxs`.

Testing:

```
python benchmarks/dynamo/cachebench.py --mode training --benchmark torchbench --model BERT_pytorch --device cuda --repeat 1 --dynamic --output="dynamic.json"
```

The cache now hits properly on FXGraphCache. AOTAutogradCache has a guard failure; I will look into that as a followup.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148130
Approved by: https://github.com/eellison
2025-03-23 02:57:58 +00:00
Brian Hirsh
7f836b747f partitioner: ensure collectives saved by SAC that are actually unused in the bw are properly not saved (#149652)
This PR fixes one of the issues described here: https://github.com/pytorch/torchtitan/issues/866#issuecomment-2726015248

I spent some time trying to write a unit test and ultimately failed. If folks are interested, I can spend more time trying, but otherwise I have an E2E test with torchtitan. Command:
```
CUDA_VISIBLE_DEVICES=1,2,3,4 NGPU=4 CONFIG_FILE="./torchtitan/models/llama/train_configs/llama3_8b.toml" tlp ./run_train.sh --training.steps=30  --training.tensor_parallel_degree=2 --training.compile --experimental.enable_async_tensor_parallel
```

here's the backward graph generated prior to the PR: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/hirsheybar/f7d17388-42c2-4d7e-8a55-a00387341ecb/custom/rank_0/-_0_0_0/aot_backward_graph_9.txt?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000

and new backward graph with the PR: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/hirsheybar/ab8576fc-98c1-4915-af47-699aa8e2557e/custom/rank_0/-_0_0_0/aot_backward_graph_9.txt?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000

The main difference is that the input arg `reduce_scatter_tensor_1` is dead code in the bw graph, causing us to unnecessarily save a giant `reduce_scatter` for bw. With the PR, we properly ensure that it is not saved for backward.

More comments in the PR, but the main thing going on is that:

(1) We have some existing logic that checks for activations that are actually dead code in the backward, and removes them

(2) collectives are not properly handled by this code. Why? Collectives are **always** followed by a `wait_tensor()` call, so we need to go one node further and check whether the "dead" collective has a wait_tensor user that is also dead (see the sketch below).
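
A minimal sketch of that extended check, not the actual partitioner code (the functional-collective `wait_tensor` op name is an assumption):

```
import torch

def is_effectively_dead(node) -> bool:
    """Return True if this fx.Node has no live users in the backward graph.

    A collective's only user is typically a wait_tensor() call, so we look one
    node further: the collective counts as dead if every user is a wait_tensor
    that is itself dead.
    """
    users = list(node.users)
    if not users:
        return True
    return all(
        u.target is torch.ops._c10d_functional.wait_tensor.default  # assumed op name
        and is_effectively_dead(u)
        for u in users
    )
```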

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149652
Approved by: https://github.com/zou3519
ghstack dependencies: #149514
2025-03-21 22:09:19 +00:00
Nikita Shulga
68dfd44e50 Do not depend on numpy during the import (#149683)
A good followup, though, would be to use torch primitives instead of numpy here.
Fixes https://github.com/pytorch/pytorch/issues/149681

Test plan: Monkey-patch 2.7.0-rc and run `python -c "import torch;print(torch.compile(lambda x:x.sin() + x.cos())(torch.rand(32)))"`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149683
Approved by: https://github.com/seemethere
2025-03-21 08:14:57 +00:00
Simon Fan
e481615bc7 [aot] always lower the backward with a deepcopy (#149229)
FIXES https://github.com/pytorch/pytorch/issues/149105

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149229
Approved by: https://github.com/bdhirsh
2025-03-21 01:47:13 +00:00
Avik Chaudhuri
6237495fcf torch.Size input (#149414)
Summary: Support for `torch.Size` inputs was patchy before because `unflatten_fn` for this type returned a tuple. This PR cleans this up.
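
A hedged example of the kind of input this covers (the exact export call is an assumption and may require a recent build):

```
import torch

class Reshape(torch.nn.Module):
    def forward(self, x, size):
        # `size` arrives as a torch.Size and is pytree-flattened/unflattened
        return x.reshape(size)

ep = torch.export.export(Reshape(), (torch.randn(2, 3), torch.Size([3, 2])))
print(ep.module()(torch.randn(2, 3), torch.Size([3, 2])).shape)
```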

Fixes #149158

Test Plan: added test

Differential Revision: D71403635

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149414
Approved by: https://github.com/yushangdi
2025-03-20 16:23:13 +00:00
IvanKobzarev
2c4bc65366 [aotd] Guess tangents stride as output strides (#144579)
AOTDispatch, while preparing the AOT backward graph, does not know the real tangents that the user will pass when running backward.

AOTD guesses the tangents. Previously, we guessed that the memory format of the tangents would match the memory format of the corresponding outputs, and if the tangents specified at runtime did not have the memory format we guessed during compilation, AOTD coerced (copied) them to the guessed memory_format.

But as Horace found, there are popular use cases where the outputs of the compiled region are in a specific memory_format, e.g. a 4D tensor with dims 1 and 2 transposed.

https://github.com/karpathy/nanoGPT/blob/master/model.py#L57

This PR changes the logic so that AOTD expects tangents with the same strides as the outputs. As a result, it avoids coercion in the transposed-dims case.

Limitations:
We keep guessing memory_format for:
1/ Dynamic shapes (needs more changes)
2/ Tensor subclasses (needs more changes)

Other changes:
test_torchinductor was always creating contiguous tangents via `torch.randn()`; they are changed to `torch.randn_like()` so that computation is compared with the same strides.

(E.g. for CUDA float16, strides affect numerics for FFT ops.)
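
An illustrative sketch of the targeted case (a compiled region whose output is a transposed 4D tensor, loosely modeled on the linked nanoGPT pattern; not taken from the PR's tests):

```
import torch

@torch.compile
def f(x):
    # (B, H, T, D) -> (B, T, H, D): the output is non-contiguous
    return x.transpose(1, 2)

x = torch.randn(2, 4, 8, 16, requires_grad=True)
out = f(x)
# randn_like preserves the output's strides, so the tangent matches the guessed
# strides and no coercion copy should be needed at backward time.
grad = torch.randn_like(out)
out.backward(grad)
```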

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144579
Approved by: https://github.com/bdhirsh
2025-03-20 15:41:36 +00:00
Brian Hirsh
f06e366532 partitioner: treat inputs with static indices as free to save (#148922)
Fixes https://github.com/pytorch/pytorch/issues/141881

internal xref: https://fb.workplace.com/groups/1075192433118967/posts/1538435030128036/?comment_id=1556782068293332

I tried to make a test case out of the code linked in that github issue. The setup + bad outcome today was as follows:

(1) you have a graph where one of its inputs is a model weight

(2) in the backward, you do some downstream compute on `weight`, `tmp = f(weight)`, where (a) `tmp` is smaller than `weight`, and (b) the compute is trivially fusible into other kernels (so the partitioner thinks it is "free" to recompute)

(3) since `sizeof(tmp) < sizeof(weight)` and the recompute is free, the partitioner decides that it would be strictly better to save `tmp` for backward instead of weight

(4) this is bad: `weight` is a static tensor that sits in GPU memory for the duration of your entire training loop, so saving it for backward has no negative impact on peak memory. Since we're saving `tmp` instead, we end up unnecessarily increasing peak memory. In particular, the repro involves an autograd.Function in eager that saves the weight for backward, so we end up hitting higher peak memory in compile.

The fix I'm trying out in this PR is to tell the partitioner that graph inputs that we know have static addresses (aka parameters) are "free" to save.

Below is the fw/bw graph before my change, where you can see that instead of `primals_2` being saved for backward, we save `t_8` (which involves some low-precision downstream compute on `primals_2` that is only needed in the backward).

```
 ===== Forward graph 0 =====
 /data/users/hirsheybar/checkout2/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "bf16[64, 64][64, 1]cuda:0", primals_2: "bf16[64, 64][64, 1]cuda:0", primals_3: "bf16[64][1]cuda:0"):
         # File: /data/users/hirsheybar/checkout2/pytorch/test/dynamo/test_repros.py:6943 in forward, code: out = Fp8LinearFn.apply(
        abs_1: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.abs.default(primals_1)
        view: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(abs_1, [64, 1, 64]);  abs_1 = None
        amax: "bf16[64, 1][1, 1]cuda:0" = torch.ops.aten.amax.default(view, [-1]);  view = None
        abs_2: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.abs.default(primals_2)
        view_1: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(abs_2, [64, 1, 64]);  abs_2 = None
        amax_1: "bf16[64, 1][1, 1]cuda:0" = torch.ops.aten.amax.default(view_1, [-1]);  view_1 = None
        _to_copy: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten._to_copy.default(amax, dtype = torch.float32);  amax = None
        clamp: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.clamp.default(_to_copy, 1e-12);  _to_copy = None
        div: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.div.Tensor(clamp, 448.0);  clamp = None
        reciprocal: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.reciprocal.default(div)
        view_2: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(primals_1, [64, 1, 64])
        view_3: "bf16[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_2, [64, 1, 1, 64]);  view_2 = None
        slice_1: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(reciprocal, 0, 0, 9223372036854775807);  reciprocal = None
        unsqueeze: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_1, 1);  slice_1 = None
        slice_2: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze, 2, 0, 9223372036854775807);  unsqueeze = None
        unsqueeze_1: "f32[64, 1, 1, 1][1, 1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_2, 3);  slice_2 = None
        mul: "f32[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_3, unsqueeze_1);  view_3 = unsqueeze_1 = None
        view_4: "f32[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(mul, [64, 1, 64]);  mul = None
        view_5: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_4, [64, 64]);  view_4 = None
        _to_copy_1: "f8e4m3fn[64, 64][64, 1]cuda:0" = torch.ops.aten._to_copy.default(view_5, dtype = torch.float8_e4m3fn);  view_5 = None
        _to_copy_2: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten._to_copy.default(amax_1, dtype = torch.float32)
        clamp_1: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.clamp.default(_to_copy_2, 1e-12);  _to_copy_2 = None
        div_1: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.div.Tensor(clamp_1, 448.0);  clamp_1 = None
        reciprocal_1: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.reciprocal.default(div_1)
        view_6: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(primals_2, [64, 1, 64])
        view_7: "bf16[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_6, [64, 1, 1, 64]);  view_6 = None
        slice_3: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(reciprocal_1, 0, 0, 9223372036854775807);  reciprocal_1 = None
        unsqueeze_2: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_3, 1);  slice_3 = None
        slice_4: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_2, 2, 0, 9223372036854775807);  unsqueeze_2 = None
        unsqueeze_3: "f32[64, 1, 1, 1][1, 1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_4, 3);  slice_4 = None
        mul_1: "f32[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_7, unsqueeze_3);  view_7 = unsqueeze_3 = None
        view_8: "f32[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(mul_1, [64, 1, 64]);  mul_1 = None
        view_9: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_8, [64, 64]);  view_8 = None
        _to_copy_3: "f8e4m3fn[64, 64][64, 1]cuda:0" = torch.ops.aten._to_copy.default(view_9, dtype = torch.float8_e4m3fn);  view_9 = None
        t: "f32[1, 64][1, 1]cuda:0" = torch.ops.aten.t.default(div_1);  div_1 = None
        new_ones: "f32[1, 1][1, 1]cuda:0" = torch.ops.aten.new_ones.default(div, [1, 1], pin_memory = False)
        new_ones_1: "f32[1, 1][1, 1]cuda:0" = torch.ops.aten.new_ones.default(t, [1, 1], pin_memory = False)
        t_2: "f8e4m3fn[64, 64][1, 64]cuda:0" = torch.ops.aten.t.default(_to_copy_3);  _to_copy_3 = None
        t_3: "f32[1, 1][1, 1]cuda:0" = torch.ops.aten.t.default(new_ones_1);  new_ones_1 = None
        _scaled_mm: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten._scaled_mm.default(_to_copy_1, t_2, new_ones, t_3, None, None, torch.bfloat16);  _to_copy_1 = t_2 = new_ones = t_3 = None
        view_10: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(_scaled_mm, [64, 1, 64]);  _scaled_mm = None
        view_11: "bf16[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_10, [64, 1, 1, 64]);  view_10 = None
        slice_5: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(div, 0, 0, 9223372036854775807);  div = None
        unsqueeze_4: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_5, 1);  slice_5 = None
        slice_6: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_4, 2, 0, 9223372036854775807);  unsqueeze_4 = None
        unsqueeze_5: "f32[64, 1, 1, 1][1, 1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_6, 3);  slice_6 = None
        mul_2: "f32[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_11, unsqueeze_5);  view_11 = unsqueeze_5 = None
        view_12: "f32[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(mul_2, [64, 1, 64]);  mul_2 = None
        view_13: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_12, [64, 64]);  view_12 = None
        view_14: "f32[1, 64, 64][4096, 64, 1]cuda:0" = torch.ops.aten.view.default(view_13, [1, 64, 64]);  view_13 = None
        view_15: "f32[1, 64, 64, 1][4096, 64, 1, 1]cuda:0" = torch.ops.aten.view.default(view_14, [1, 64, 64, 1]);  view_14 = None
        slice_7: "f32[1, 64][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(t, 0, 0, 9223372036854775807);  t = None
        unsqueeze_6: "f32[1, 1, 64][1, 64, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_7, 1);  slice_7 = None
        slice_8: "f32[1, 1, 64][1, 64, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_6, 2, 0, 9223372036854775807);  unsqueeze_6 = None
        unsqueeze_7: "f32[1, 1, 64, 1][1, 64, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_8, 3);  slice_8 = None
        mul_3: "f32[1, 64, 64, 1][4096, 64, 1, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_15, unsqueeze_7);  view_15 = unsqueeze_7 = None
        view_16: "f32[64, 64, 1][64, 1, 1]cuda:0" = torch.ops.aten.view.default(mul_3, [64, 64, 1]);  mul_3 = None
        view_17: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_16, [64, 64]);  view_16 = None
        _to_copy_4: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten._to_copy.default(view_17, dtype = torch.bfloat16);  view_17 = None
        add: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.add.Tensor(_to_copy_4, primals_3);  _to_copy_4 = primals_3 = None
        t_4: "bf16[64, 64][1, 64]cuda:0" = torch.ops.aten.t.default(primals_2);  primals_2 = None
        clone: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.clone.default(t_4, memory_format = torch.contiguous_format);  t_4 = None
        t_5: "bf16[1, 64][1, 1]cuda:0" = torch.ops.aten.t.default(amax_1);  amax_1 = None
        view_21: "bf16[1, 1, 64][1, 64, 1]cuda:0" = torch.ops.aten.view.default(t_5, [1, 1, 64]);  t_5 = None
        amax_3: "bf16[1, 1][1, 1]cuda:0" = torch.ops.aten.amax.default(view_21, [-1]);  view_21 = None
        unsqueeze_8: "bf16[1, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(amax_3, 1);  amax_3 = None
        expand: "bf16[1, 64, 1][1, 0, 1]cuda:0" = torch.ops.aten.expand.default(unsqueeze_8, [1, 64, 1])
        clone_1: "bf16[1, 64, 1][64, 1, 1]cuda:0" = torch.ops.aten.clone.default(expand, memory_format = torch.contiguous_format);  expand = None
        view_22: "bf16[64, 1][1, 1]cuda:0" = torch.ops.aten.view.default(clone_1, [64, 1]);  clone_1 = None
        _to_copy_7: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten._to_copy.default(view_22, dtype = torch.float32);  view_22 = None
        clamp_3: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.clamp.default(_to_copy_7, 1e-12);  _to_copy_7 = None
        div_3: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.div.Tensor(clamp_3, 448.0);  clamp_3 = None
        reciprocal_3: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.reciprocal.default(div_3);  div_3 = None
        view_27: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(clone, [64, 1, 64]);  clone = None
        view_28: "bf16[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_27, [64, 1, 1, 64]);  view_27 = None
        slice_11: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(reciprocal_3, 0, 0, 9223372036854775807);  reciprocal_3 = None
        unsqueeze_11: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_11, 1);  slice_11 = None
        slice_12: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_11, 2, 0, 9223372036854775807);  unsqueeze_11 = None
        unsqueeze_12: "f32[64, 1, 1, 1][1, 1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_12, 3);  slice_12 = None
        mul_5: "f32[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_28, unsqueeze_12);  view_28 = unsqueeze_12 = None
        view_29: "f32[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(mul_5, [64, 1, 64]);  mul_5 = None
        view_30: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_29, [64, 64]);  view_29 = None
        _to_copy_8: "f8e4m3fn[64, 64][64, 1]cuda:0" = torch.ops.aten._to_copy.default(view_30, dtype = torch.float8_e4m3fn);  view_30 = None
        t_8: "f8e4m3fn[64, 64][1, 64]cuda:0" = torch.ops.aten.t.default(_to_copy_8);  _to_copy_8 = None

        # No stacktrace found for following nodes
        view_39: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(add, [64, 64]);  add = None
        return (view_39, primals_1, unsqueeze_8, t_8)

INFO: TRACED GRAPH
 ===== Backward graph 0 =====
 <eval_with_key>.1 class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "bf16[64, 64][64, 1]cuda:0", unsqueeze_8: "bf16[1, 1, 1][1, 1, 1]cuda:0", t_8: "f8e4m3fn[64, 64][1, 64]cuda:0", tangents_1: "bf16[64, 64][64, 1]cuda:0"):
         # File: /data/users/hirsheybar/checkout2/pytorch/test/dynamo/test_repros.py:6946 in forward, code: out = out.unflatten(0, input.shape[:-1])
        view_19: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(tangents_1, [64, 64]);  tangents_1 = None

         # File: /data/users/hirsheybar/checkout2/pytorch/test/dynamo/test_repros.py:6943 in forward, code: out = Fp8LinearFn.apply(
        abs_3: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.abs.default(view_19)
        view_20: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(abs_3, [64, 1, 64]);  abs_3 = None
        amax_2: "bf16[64, 1][1, 1]cuda:0" = torch.ops.aten.amax.default(view_20, [-1]);  view_20 = None
        expand: "bf16[1, 64, 1][1, 0, 1]cuda:0" = torch.ops.aten.expand.default(unsqueeze_8, [1, 64, 1]);  unsqueeze_8 = None
        clone_1: "bf16[1, 64, 1][64, 1, 1]cuda:0" = torch.ops.aten.clone.default(expand, memory_format = torch.contiguous_format);  expand = None
        view_22: "bf16[64, 1][1, 1]cuda:0" = torch.ops.aten.view.default(clone_1, [64, 1]);  clone_1 = None
        _to_copy_5: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten._to_copy.default(amax_2, dtype = torch.float32);  amax_2 = None
        clamp_2: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.clamp.default(_to_copy_5, 1e-12);  _to_copy_5 = None
        div_2: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.div.Tensor(clamp_2, 448.0);  clamp_2 = None
        reciprocal_2: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.reciprocal.default(div_2)
        view_23: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_19, [64, 1, 64])
        view_24: "bf16[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_23, [64, 1, 1, 64]);  view_23 = None
        slice_9: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(reciprocal_2, 0, 0, 9223372036854775807);  reciprocal_2 = None
        unsqueeze_9: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_9, 1);  slice_9 = None
        slice_10: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_9, 2, 0, 9223372036854775807);  unsqueeze_9 = None
        unsqueeze_10: "f32[64, 1, 1, 1][1, 1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_10, 3);  slice_10 = None
        mul_4: "f32[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_24, unsqueeze_10);  view_24 = unsqueeze_10 = None
        view_25: "f32[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(mul_4, [64, 1, 64]);  mul_4 = None
        view_26: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_25, [64, 64]);  view_25 = None
        _to_copy_6: "f8e4m3fn[64, 64][64, 1]cuda:0" = torch.ops.aten._to_copy.default(view_26, dtype = torch.float8_e4m3fn);  view_26 = None
        _to_copy_7: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten._to_copy.default(view_22, dtype = torch.float32);  view_22 = None
        clamp_3: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.clamp.default(_to_copy_7, 1e-12);  _to_copy_7 = None
        div_3: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.div.Tensor(clamp_3, 448.0);  clamp_3 = None
        t_6: "f32[1, 64][1, 1]cuda:0" = torch.ops.aten.t.default(div_3);  div_3 = None
        new_ones_2: "f32[1, 1][1, 1]cuda:0" = torch.ops.aten.new_ones.default(div_2, [1, 1], pin_memory = False)
        new_ones_3: "f32[1, 1][1, 1]cuda:0" = torch.ops.aten.new_ones.default(t_6, [1, 1], pin_memory = False)
        t_9: "f32[1, 1][1, 1]cuda:0" = torch.ops.aten.t.default(new_ones_3);  new_ones_3 = None
        _scaled_mm_1: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten._scaled_mm.default(_to_copy_6, t_8, new_ones_2, t_9, None, None, torch.bfloat16);  _to_copy_6 = t_8 = new_ones_2 = t_9 = None
        view_31: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(_scaled_mm_1, [64, 1, 64]);  _scaled_mm_1 = None
        view_32: "bf16[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_31, [64, 1, 1, 64]);  view_31 = None
        slice_13: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(div_2, 0, 0, 9223372036854775807);  div_2 = None
        unsqueeze_13: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_13, 1);  slice_13 = None
        slice_14: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_13, 2, 0, 9223372036854775807);  unsqueeze_13 = None
        unsqueeze_14: "f32[64, 1, 1, 1][1, 1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_14, 3);  slice_14 = None
        mul_6: "f32[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_32, unsqueeze_14);  view_32 = unsqueeze_14 = None
        view_33: "f32[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(mul_6, [64, 1, 64]);  mul_6 = None
        view_34: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_33, [64, 64]);  view_33 = None
        view_35: "f32[1, 64, 64][4096, 64, 1]cuda:0" = torch.ops.aten.view.default(view_34, [1, 64, 64]);  view_34 = None
        view_36: "f32[1, 64, 64, 1][4096, 64, 1, 1]cuda:0" = torch.ops.aten.view.default(view_35, [1, 64, 64, 1]);  view_35 = None
        slice_15: "f32[1, 64][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(t_6, 0, 0, 9223372036854775807);  t_6 = None
        unsqueeze_15: "f32[1, 1, 64][1, 64, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_15, 1);  slice_15 = None
        slice_16: "f32[1, 1, 64][1, 64, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_15, 2, 0, 9223372036854775807);  unsqueeze_15 = None
        unsqueeze_16: "f32[1, 1, 64, 1][1, 64, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_16, 3);  slice_16 = None
        mul_7: "f32[1, 64, 64, 1][4096, 64, 1, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_36, unsqueeze_16);  view_36 = unsqueeze_16 = None
        view_37: "f32[64, 64, 1][64, 1, 1]cuda:0" = torch.ops.aten.view.default(mul_7, [64, 64, 1]);  mul_7 = None
        view_38: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_37, [64, 64]);  view_37 = None
        _to_copy_9: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten._to_copy.default(view_38, dtype = torch.bfloat16);  view_38 = None
        t_10: "bf16[64, 64][1, 64]cuda:0" = torch.ops.aten.t.default(view_19)
        mm: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.mm.default(t_10, primals_1);  t_10 = primals_1 = None
        sum_1: "bf16[64][1]cuda:0" = torch.ops.aten.sum.dim_IntList(view_19, [0]);  view_19 = None
        return (_to_copy_9, mm, sum_1)

```

With the change, we save `primals_2` for backward instead:

```
 ===== Forward graph 0 =====
 /data/users/hirsheybar/checkout2/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "bf16[64, 64][64, 1]cuda:0", primals_2: "bf16[64, 64][64, 1]cuda:0", primals_3: "bf16[64][1]cuda:0"):
         # File: /data/users/hirsheybar/checkout2/pytorch/test/dynamo/test_repros.py:6943 in forward, code: out = Fp8LinearFn.apply(
        abs_1: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.abs.default(primals_1)
        view: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(abs_1, [64, 1, 64]);  abs_1 = None
        amax: "bf16[64, 1][1, 1]cuda:0" = torch.ops.aten.amax.default(view, [-1]);  view = None
        abs_2: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.abs.default(primals_2)
        view_1: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(abs_2, [64, 1, 64]);  abs_2 = None
        amax_1: "bf16[64, 1][1, 1]cuda:0" = torch.ops.aten.amax.default(view_1, [-1]);  view_1 = None
        _to_copy: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten._to_copy.default(amax, dtype = torch.float32);  amax = None
        clamp: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.clamp.default(_to_copy, 1e-12);  _to_copy = None
        div: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.div.Tensor(clamp, 448.0);  clamp = None
        reciprocal: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.reciprocal.default(div)
        view_2: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(primals_1, [64, 1, 64])
        view_3: "bf16[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_2, [64, 1, 1, 64]);  view_2 = None
        slice_1: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(reciprocal, 0, 0, 9223372036854775807);  reciprocal = None
        unsqueeze: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_1, 1);  slice_1 = None
        slice_2: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze, 2, 0, 9223372036854775807);  unsqueeze = None
        unsqueeze_1: "f32[64, 1, 1, 1][1, 1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_2, 3);  slice_2 = None
        mul: "f32[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_3, unsqueeze_1);  view_3 = unsqueeze_1 = None
        view_4: "f32[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(mul, [64, 1, 64]);  mul = None
        view_5: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_4, [64, 64]);  view_4 = None
        _to_copy_1: "f8e4m3fn[64, 64][64, 1]cuda:0" = torch.ops.aten._to_copy.default(view_5, dtype = torch.float8_e4m3fn);  view_5 = None
        _to_copy_2: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten._to_copy.default(amax_1, dtype = torch.float32)
        clamp_1: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.clamp.default(_to_copy_2, 1e-12);  _to_copy_2 = None
        div_1: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.div.Tensor(clamp_1, 448.0);  clamp_1 = None
        reciprocal_1: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.reciprocal.default(div_1)
        view_6: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(primals_2, [64, 1, 64])
        view_7: "bf16[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_6, [64, 1, 1, 64]);  view_6 = None
        slice_3: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(reciprocal_1, 0, 0, 9223372036854775807);  reciprocal_1 = None
        unsqueeze_2: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_3, 1);  slice_3 = None
        slice_4: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_2, 2, 0, 9223372036854775807);  unsqueeze_2 = None
        unsqueeze_3: "f32[64, 1, 1, 1][1, 1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_4, 3);  slice_4 = None
        mul_1: "f32[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_7, unsqueeze_3);  view_7 = unsqueeze_3 = None
        view_8: "f32[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(mul_1, [64, 1, 64]);  mul_1 = None
        view_9: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_8, [64, 64]);  view_8 = None
        _to_copy_3: "f8e4m3fn[64, 64][64, 1]cuda:0" = torch.ops.aten._to_copy.default(view_9, dtype = torch.float8_e4m3fn);  view_9 = None
        t: "f32[1, 64][1, 1]cuda:0" = torch.ops.aten.t.default(div_1);  div_1 = None
        new_ones: "f32[1, 1][1, 1]cuda:0" = torch.ops.aten.new_ones.default(div, [1, 1], pin_memory = False)
        new_ones_1: "f32[1, 1][1, 1]cuda:0" = torch.ops.aten.new_ones.default(t, [1, 1], pin_memory = False)
        t_2: "f8e4m3fn[64, 64][1, 64]cuda:0" = torch.ops.aten.t.default(_to_copy_3);  _to_copy_3 = None
        t_3: "f32[1, 1][1, 1]cuda:0" = torch.ops.aten.t.default(new_ones_1);  new_ones_1 = None
        _scaled_mm: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten._scaled_mm.default(_to_copy_1, t_2, new_ones, t_3, None, None, torch.bfloat16);  _to_copy_1 = t_2 = new_ones = t_3 = None
        view_10: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(_scaled_mm, [64, 1, 64]);  _scaled_mm = None
        view_11: "bf16[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_10, [64, 1, 1, 64]);  view_10 = None
        slice_5: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(div, 0, 0, 9223372036854775807);  div = None
        unsqueeze_4: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_5, 1);  slice_5 = None
        slice_6: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_4, 2, 0, 9223372036854775807);  unsqueeze_4 = None
        unsqueeze_5: "f32[64, 1, 1, 1][1, 1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_6, 3);  slice_6 = None
        mul_2: "f32[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_11, unsqueeze_5);  view_11 = unsqueeze_5 = None
        view_12: "f32[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(mul_2, [64, 1, 64]);  mul_2 = None
        view_13: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_12, [64, 64]);  view_12 = None
        view_14: "f32[1, 64, 64][4096, 64, 1]cuda:0" = torch.ops.aten.view.default(view_13, [1, 64, 64]);  view_13 = None
        view_15: "f32[1, 64, 64, 1][4096, 64, 1, 1]cuda:0" = torch.ops.aten.view.default(view_14, [1, 64, 64, 1]);  view_14 = None
        slice_7: "f32[1, 64][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(t, 0, 0, 9223372036854775807);  t = None
        unsqueeze_6: "f32[1, 1, 64][1, 64, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_7, 1);  slice_7 = None
        slice_8: "f32[1, 1, 64][1, 64, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_6, 2, 0, 9223372036854775807);  unsqueeze_6 = None
        unsqueeze_7: "f32[1, 1, 64, 1][1, 64, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_8, 3);  slice_8 = None
        mul_3: "f32[1, 64, 64, 1][4096, 64, 1, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_15, unsqueeze_7);  view_15 = unsqueeze_7 = None
        view_16: "f32[64, 64, 1][64, 1, 1]cuda:0" = torch.ops.aten.view.default(mul_3, [64, 64, 1]);  mul_3 = None
        view_17: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_16, [64, 64]);  view_16 = None
        _to_copy_4: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten._to_copy.default(view_17, dtype = torch.bfloat16);  view_17 = None
        add: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.add.Tensor(_to_copy_4, primals_3);  _to_copy_4 = primals_3 = None
        t_5: "bf16[1, 64][1, 1]cuda:0" = torch.ops.aten.t.default(amax_1);  amax_1 = None
        view_21: "bf16[1, 1, 64][1, 64, 1]cuda:0" = torch.ops.aten.view.default(t_5, [1, 1, 64]);  t_5 = None
        amax_3: "bf16[1, 1][1, 1]cuda:0" = torch.ops.aten.amax.default(view_21, [-1]);  view_21 = None
        unsqueeze_8: "bf16[1, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(amax_3, 1);  amax_3 = None

        # No stacktrace found for following nodes
        view_39: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(add, [64, 64]);  add = None
        return (view_39, primals_1, primals_2, unsqueeze_8)

INFO: TRACED GRAPH
 ===== Backward graph 0 =====
 <eval_with_key>.1 class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "bf16[64, 64][64, 1]cuda:0", primals_2: "bf16[64, 64][64, 1]cuda:0", unsqueeze_8: "bf16[1, 1, 1][1, 1, 1]cuda:0", tangents_1: "bf16[64, 64][64, 1]cuda:0"):
         # File: /data/users/hirsheybar/checkout2/pytorch/test/dynamo/test_repros.py:6946 in forward, code: out = out.unflatten(0, input.shape[:-1])
        view_19: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(tangents_1, [64, 64]);  tangents_1 = None

         # File: /data/users/hirsheybar/checkout2/pytorch/test/dynamo/test_repros.py:6943 in forward, code: out = Fp8LinearFn.apply(
        t_4: "bf16[64, 64][1, 64]cuda:0" = torch.ops.aten.t.default(primals_2);  primals_2 = None
        clone: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.clone.default(t_4, memory_format = torch.contiguous_format);  t_4 = None
        abs_3: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.abs.default(view_19)
        view_20: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(abs_3, [64, 1, 64]);  abs_3 = None
        amax_2: "bf16[64, 1][1, 1]cuda:0" = torch.ops.aten.amax.default(view_20, [-1]);  view_20 = None
        expand: "bf16[1, 64, 1][1, 0, 1]cuda:0" = torch.ops.aten.expand.default(unsqueeze_8, [1, 64, 1]);  unsqueeze_8 = None
        clone_1: "bf16[1, 64, 1][64, 1, 1]cuda:0" = torch.ops.aten.clone.default(expand, memory_format = torch.contiguous_format);  expand = None
        view_22: "bf16[64, 1][1, 1]cuda:0" = torch.ops.aten.view.default(clone_1, [64, 1]);  clone_1 = None
        _to_copy_5: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten._to_copy.default(amax_2, dtype = torch.float32);  amax_2 = None
        clamp_2: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.clamp.default(_to_copy_5, 1e-12);  _to_copy_5 = None
        div_2: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.div.Tensor(clamp_2, 448.0);  clamp_2 = None
        reciprocal_2: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.reciprocal.default(div_2)
        view_23: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_19, [64, 1, 64])
        view_24: "bf16[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_23, [64, 1, 1, 64]);  view_23 = None
        slice_9: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(reciprocal_2, 0, 0, 9223372036854775807);  reciprocal_2 = None
        unsqueeze_9: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_9, 1);  slice_9 = None
        slice_10: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_9, 2, 0, 9223372036854775807);  unsqueeze_9 = None
        unsqueeze_10: "f32[64, 1, 1, 1][1, 1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_10, 3);  slice_10 = None
        mul_4: "f32[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_24, unsqueeze_10);  view_24 = unsqueeze_10 = None
        view_25: "f32[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(mul_4, [64, 1, 64]);  mul_4 = None
        view_26: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_25, [64, 64]);  view_25 = None
        _to_copy_6: "f8e4m3fn[64, 64][64, 1]cuda:0" = torch.ops.aten._to_copy.default(view_26, dtype = torch.float8_e4m3fn);  view_26 = None
        _to_copy_7: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten._to_copy.default(view_22, dtype = torch.float32);  view_22 = None
        clamp_3: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.clamp.default(_to_copy_7, 1e-12);  _to_copy_7 = None
        div_3: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.div.Tensor(clamp_3, 448.0);  clamp_3 = None
        reciprocal_3: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.reciprocal.default(div_3)
        view_27: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(clone, [64, 1, 64]);  clone = None
        view_28: "bf16[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_27, [64, 1, 1, 64]);  view_27 = None
        slice_11: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(reciprocal_3, 0, 0, 9223372036854775807);  reciprocal_3 = None
        unsqueeze_11: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_11, 1);  slice_11 = None
        slice_12: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_11, 2, 0, 9223372036854775807);  unsqueeze_11 = None
        unsqueeze_12: "f32[64, 1, 1, 1][1, 1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_12, 3);  slice_12 = None
        mul_5: "f32[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_28, unsqueeze_12);  view_28 = unsqueeze_12 = None
        view_29: "f32[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(mul_5, [64, 1, 64]);  mul_5 = None
        view_30: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_29, [64, 64]);  view_29 = None
        _to_copy_8: "f8e4m3fn[64, 64][64, 1]cuda:0" = torch.ops.aten._to_copy.default(view_30, dtype = torch.float8_e4m3fn);  view_30 = None
        t_6: "f32[1, 64][1, 1]cuda:0" = torch.ops.aten.t.default(div_3);  div_3 = None
        new_ones_2: "f32[1, 1][1, 1]cuda:0" = torch.ops.aten.new_ones.default(div_2, [1, 1], pin_memory = False)
        new_ones_3: "f32[1, 1][1, 1]cuda:0" = torch.ops.aten.new_ones.default(t_6, [1, 1], pin_memory = False)
        t_8: "f8e4m3fn[64, 64][1, 64]cuda:0" = torch.ops.aten.t.default(_to_copy_8);  _to_copy_8 = None
        t_9: "f32[1, 1][1, 1]cuda:0" = torch.ops.aten.t.default(new_ones_3);  new_ones_3 = None
        _scaled_mm_1: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten._scaled_mm.default(_to_copy_6, t_8, new_ones_2, t_9, None, None, torch.bfloat16);  _to_copy_6 = t_8 = new_ones_2 = t_9 = None
        view_31: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(_scaled_mm_1, [64, 1, 64]);  _scaled_mm_1 = None
        view_32: "bf16[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_31, [64, 1, 1, 64]);  view_31 = None
        slice_13: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(div_2, 0, 0, 9223372036854775807);  div_2 = None
        unsqueeze_13: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_13, 1);  slice_13 = None
        slice_14: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_13, 2, 0, 9223372036854775807);  unsqueeze_13 = None
        unsqueeze_14: "f32[64, 1, 1, 1][1, 1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_14, 3);  slice_14 = None
        mul_6: "f32[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_32, unsqueeze_14);  view_32 = unsqueeze_14 = None
        view_33: "f32[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(mul_6, [64, 1, 64]);  mul_6 = None
        view_34: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_33, [64, 64]);  view_33 = None
        view_35: "f32[1, 64, 64][4096, 64, 1]cuda:0" = torch.ops.aten.view.default(view_34, [1, 64, 64]);  view_34 = None
        view_36: "f32[1, 64, 64, 1][4096, 64, 1, 1]cuda:0" = torch.ops.aten.view.default(view_35, [1, 64, 64, 1]);  view_35 = None
        slice_15: "f32[1, 64][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(t_6, 0, 0, 9223372036854775807);  t_6 = None
        unsqueeze_15: "f32[1, 1, 64][1, 64, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_15, 1);  slice_15 = None
        slice_16: "f32[1, 1, 64][1, 64, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_15, 2, 0, 9223372036854775807);  unsqueeze_15 = None
        unsqueeze_16: "f32[1, 1, 64, 1][1, 64, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_16, 3);  slice_16 = None
        mul_7: "f32[1, 64, 64, 1][4096, 64, 1, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_36, unsqueeze_16);  view_36 = unsqueeze_16 = None
        view_37: "f32[64, 64, 1][64, 1, 1]cuda:0" = torch.ops.aten.view.default(mul_7, [64, 64, 1]);  mul_7 = None
        view_38: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_37, [64, 64]);  view_37 = None
        _to_copy_9: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten._to_copy.default(view_38, dtype = torch.bfloat16);  view_38 = None
        t_10: "bf16[64, 64][1, 64]cuda:0" = torch.ops.aten.t.default(view_19)
        mm: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.mm.default(t_10, primals_1);  t_10 = primals_1 = None
        sum_1: "bf16[64][1]cuda:0" = torch.ops.aten.sum.dim_IntList(view_19, [0]);  view_19 = None
        return (_to_copy_9, mm, sum_1)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148922
Approved by: https://github.com/zou3519
2025-03-18 20:08:11 +00:00
Aaron Gokaslan
a0ac63cbd9 [BE]: Apply ruff PERF403 to use dict comprehensions more often (#149257)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149257
Approved by: https://github.com/jansel
2025-03-18 00:46:07 +00:00
PyTorch MergeBot
24cfeec2c7 Revert "[BE]: Apply ruff PERF403 to use dict comprehensions more often (#149257)"
This reverts commit bfee141666.

Reverted https://github.com/pytorch/pytorch/pull/149257 on behalf of https://github.com/malfet due to Let's see if it helps restore compiler benchmark sanity, see 8bc7bd94a5/1 ([comment](https://github.com/pytorch/pytorch/pull/149257#issuecomment-2731133812))
2025-03-17 22:57:00 +00:00
Aaron Gokaslan
bfee141666 [BE]: Apply ruff PERF403 to use dict comprehensions more often (#149257)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149257
Approved by: https://github.com/jansel
2025-03-16 23:52:58 +00:00
Brian Hirsh
3646d4dbc8 [partitioner] always ban compiler-driven recompute of collectives by default (#147561)
This should fix the hang in https://fb.workplace.com/groups/1075192433118967/permalink/1603268720311333/

The argument here is that:

(1) in general, it is not safe for the partitioner to sometimes choose to recompute collectives in the backward. Why? If we are running a distributed job, where many ranks are compiling at the same time, we need every rank to make a consistent decision about which collectives are recomputed for backward. If we let each compiler instance make its own choice without any cross-rank communication, they can make different choices and cause NCCL hangs (see the link above)

(2) later on, we'll want an `spmd_mode` flag that causes the compiler to issue collectives and communicate info across ranks. Once we have such a config, then turning it on should make it safe for the partitioner to potentially choose to recompute collectives (and agree on the binary "recompute-or-save" choice across all ranks)

(3) even without an `spmd_mode`, users can override this choice by using `torch.utils.checkpoint()` in their user code. User checkpointing generally always overrides the partitioner, and this should be safe because we expect the user to apply checkpointing consistently across ranks (see the sketch below).
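
A hedged sketch of that user-side override (the block body is a stand-in; in a real distributed job it would contain the collective being recomputed, applied identically on every rank):

```
import torch
from torch.utils.checkpoint import checkpoint

def block(x, w):
    # stand-in for a region that, in a distributed setup, contains a collective
    return torch.nn.functional.linear(x, w).relu()

x = torch.randn(8, 16, requires_grad=True)
w = torch.randn(16, 16, requires_grad=True)
out = checkpoint(block, x, w, use_reentrant=False)  # user-requested recompute
out.sum().backward()
```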

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147561
Approved by: https://github.com/zou3519
2025-03-13 03:36:13 +00:00
Brian Hirsh
621dadd4ca partitioner: when materializing unbacked tensor intermediates, apply hint to symbol, not expr (#144097)
Fixes https://github.com/pytorch/pytorch/issues/144095

open to suggestions: the `hint_int(..., fallback=...)` API feels like a bit of a footgun, because:

(1) we use the same guess for every unbacked symint (both symbols, and compound expressions)
(2) the user may have established some relationship between some unbacked symints that we are not taking into account.

I'm not sure how real of an issue (2) is - is it common to e.g. generate two unbacked symints, and then add a runtime assert that they are unequal?

Instead I did something simpler that's just enough to fix the linked issue: if we have a sympy expression containing an unbacked symbol (e.g. `u0 + 1`), then the partitioner will now fill in the symbol with our guess instead of the expression (plugging in `u0=4096` gets us 4097). This was important for an internal custom op that had some logic like this (a numeric illustration follows the snippet):
```
def custom_op(x: [u0], y: [u0 + 1]):
    assert x.shape[0] == y.shape[0] - 1
    ...
```
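
A small numeric illustration of the difference (the 4096 fallback comes from the example above; this is a sketch, not the partitioner's implementation):

```
import sympy

u0 = sympy.Symbol("u0", positive=True, integer=True)
expr = u0 + 1
fallback = 4096

old_hint = fallback                  # same guess applied to the whole expression
new_hint = expr.subs(u0, fallback)   # guess applied to the symbol, then evaluated

assert old_hint == 4096              # violates x.shape[0] == y.shape[0] - 1
assert new_hint == 4097              # satisfies the custom op's assert
```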

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144097
Approved by: https://github.com/laithsakka
2025-03-11 02:11:57 +00:00
Simon Fan
666508eb17 [aot cache][ca] remove restriction on caching ca's aot inference graph (#148491)
We still can't cache CA's AOT inference graph yet: the CA functional ops aren't serializable.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148491
Approved by: https://github.com/jamesjwu
ghstack dependencies: #148381
2025-03-08 06:08:26 +00:00
Simon Fan
c16cd25cf5 [ca] remove compiled_autograd_tracing (#148381)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148381
Approved by: https://github.com/jansel
2025-03-08 06:08:26 +00:00
Luca Wehrstedt
f80aad62fa Improve Pareto frontier plot for AutoAC (#148678)
This was added in https://github.com/pytorch/pytorch/pull/126320. It's a very nice feature, which can be used to predict memory usage for different budget values.

However, it had some limitations, notably in terms of resolution (it only sampled 21 points across the whole range thus missed many threshold values) and in distributed settings.

Here I fix those by using recursive binary searches to identify all thresholds (up to a resolution of 1e-3, which can be made configurable) and output them in SVG (to be able to discern different points), plus I add the rank to the filename and store it in a user-defined directory.
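
A minimal sketch of the recursive threshold search (the `evaluate` interface is assumed; it stands in for whatever the AutoAC estimator returns for a given budget):

```
def find_thresholds(evaluate, lo=0.0, hi=1.0, eps=1e-3):
    """Return sorted budget values at which `evaluate`'s (discrete) choice changes,
    located to within `eps` by recursive bisection instead of a fixed 21-point grid."""
    thresholds = []

    def recurse(lo, hi, lo_val, hi_val):
        if lo_val == hi_val:
            return
        if hi - lo <= eps:
            thresholds.append(hi)
            return
        mid = (lo + hi) / 2.0
        mid_val = evaluate(mid)
        recurse(lo, mid, lo_val, mid_val)
        recurse(mid, hi, mid_val, hi_val)

    recurse(lo, hi, evaluate(lo), evaluate(hi))
    return sorted(thresholds)
```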

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148678
Approved by: https://github.com/Chillee, https://github.com/fmassa
2025-03-07 13:22:29 +00:00
Animesh Jain
d43c6f0033 [invoke_subgraph] Run joint passes on the hop graphs (#139325)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139325
Approved by: https://github.com/bdhirsh, https://github.com/zou3519
ghstack dependencies: #147559
2025-03-03 23:38:14 +00:00
Animesh Jain
eb9c127341 [dynamo][optimizers] Install ID_GUARDED tensors into the Fx graph (#147824)
Earlier, with the inline flag, we were lifting id-guarded tensors as inputs to the Fx graph. But this offers no benefit. The main idea behind lifting parameters as inputs was to reuse compilation units across many instances of the nn.Module. However, if we are guarding on the `id`, we are explicitly specializing the compiled artifact to the parameter.

This PR installs the parameters back into the graph. The benefit is removal of all pre-graph bytecode to extract the id-guarded tensors from locals/globals. This increases speedup from 1.67x to 1.75x for an internal model that has a large number of optimizer parameters.
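
A hedged sketch of the kind of compiled-optimizer workload this targets (model and sizes are illustrative; the pattern of compiling a wrapper around `opt.step()` follows the standard compiled-optimizer recipe):

```
import torch

params = [torch.nn.Parameter(torch.randn(16)) for _ in range(4)]
opt = torch.optim.SGD(params, lr=0.1)

for p in params:
    p.grad = torch.randn_like(p)

@torch.compile
def step():
    opt.step()

step()  # id-guarded parameters are installed directly in the FX graph
```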

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147824
Approved by: https://github.com/jansel

Co-authored-by: Jason Ansel <jansel@meta.com>
2025-02-28 03:22:11 +00:00
eellison
481a57bc37 Support torch.compile rng selective activation checkpointing with cudagraph (#146878)
TODO:
- [x]  Add handling for when forward is invoked multiple times without invoking backward, so that the fwd/backward states are out of sync
- [x] Update rng state initialization to take from correct device
- [x]  Tests
- [x] handling of retain_graph
- [x] respect fallback random

Fix for https://github.com/pytorch/pytorch/issues/130123.

Updates the aot_eager and cudagraph compilation of `run_and_save_rng_state` to use the new mechanism added by https://github.com/pytorch/pytorch/pull/114068 for CUDAGraph safe rng states.

We have a pair of rng states for the fwd and backward respectively. In both forward and backward, the rng op gets run with `graphsafe_run_with_rng_state`, which takes in an RNG state and hooks it onto the current RNG generator before running the operator. The rng states for fwd/backward are initialized with the same value. We ensure that for any given run of the forward, the corresponding backward run will have the same rng state for the op as was observed in the forward (a user-level sketch follows the traced graphs below).

```
 ===== Forward graph 1 =====
 /data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", fwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = fwd_rng_state_0);  fwd_rng_state_0 = None
        ...

 ===== Backward graph 1 =====
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", tangents_1: "f32[4, 4][4, 1]cuda:0", bwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = bwd_rng_state_0);  bwd_rng_state_0 = None
```
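
A hedged user-level sketch of the kind of workload this enables (requires a CUDA device; plain `torch.utils.checkpoint` stands in here for selective activation checkpointing of an RNG op):

```
import torch
from torch.utils.checkpoint import checkpoint

def block(x):
    # dropout is the RNG op that must replay identically in the backward recompute
    return torch.nn.functional.dropout(x.sin(), p=0.5, training=True)

@torch.compile(mode="reduce-overhead")  # enables cudagraphs in inductor
def f(x):
    return checkpoint(block, x, use_reentrant=False).sum()

x = torch.randn(64, 64, device="cuda", requires_grad=True)
f(x).backward()
```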

There is some extra complication when a user either calls backward with retain_graph, or calls the backward in a different order as they called the forward. If a user has state fwd_rng_state0, bwd_rng_state0 and calls:
- fwd0: fwd_rng_state0 -> fwd_rng_state1
- fwd1: fwd_rng_state1 -> fwd_rng_state2
- bwd1
- bwd0

Then naively, when bwd1 is invoked the bwd rng states would not be equal to the same states that were observed in fwd1. I added handling of this in the aot runtime wrappers to detect pending backward invocations and the current position of the bwd rng states, and to update them when necessary.

Other notes:

Because nodes which appear later in the forward appear earlier in the backward, we need a separate rng state for each operator. If we reused the rng across ops, the forward and backward would be run with different rng states. I.e., not applied in the same order.

Questions for reviewers:

This does change numerics, because the rng of the op is now taken from the input rng state instead of whatever the rng would be midway through running the graph. Technically, we only need this for cuda graphs, but I'd prefer not to have an rng divergence just for cudagraph. I am making it respect `fallback_random`.

Edit: decided to apply to non cudagraphs as well, so long as fallback_random is not set

I'm initializing the rng states by cloning the current state. If you had something like 5 different rands in the model with the same shape, they'd all get the same value. This doesn't seem great. I could use some other initialization scheme, like taking the seed from the graph position. Not sure. Let me know your thoughts.

Edit: updated to be taken from randint()

Update: initializing rng states from torch.randint().

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146878
Approved by: https://github.com/anijain2305, https://github.com/bdhirsh
2025-02-28 00:47:03 +00:00
IvanKobzarev
7ae0e0b2ea [aotd] Log torch._functorch.config in tlparse (#147883)
Adding torch._functorch.config to tlparse for better debuggability.
E.g. https://github.com/pytorch/pytorch/pull/147638 happened only with `torch._functorch.config.view_replay_for_aliased_outputs=False`, which is True by default.
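
For reference, a minimal sketch of flipping that flag locally when reproducing such issues (`patch` is the standard config-module context manager; the repro function here is illustrative, not the one from the linked PR):

```
import torch
import torch._functorch.config as functorch_config

def f(x):
    a = x + 1
    return a.view(-1), a  # aliased outputs, the case the flag affects

with functorch_config.patch(view_replay_for_aliased_outputs=False):
    out = torch.compile(f)(torch.randn(4, 4))
```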

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147883
Approved by: https://github.com/bdhirsh, https://github.com/jamesjwu
2025-02-27 11:22:45 +00:00
PyTorch MergeBot
17358ce778 Revert "Support torch.compile rng selective activation checkpointing with cudagraph (#146878)"
This reverts commit ad0c879e22.

Reverted https://github.com/pytorch/pytorch/pull/146878 on behalf of https://github.com/wdvr due to lint failure ([comment](https://github.com/pytorch/pytorch/pull/146878#issuecomment-2686767956))
2025-02-27 03:36:16 +00:00
eellison
ad0c879e22 Support torch.compile rng selective activation checkpointing with cudagraph (#146878)
TODO:
- [x]  Add handling for when forward is invoked multiple times without invoking backward, so that the fwd/backward states are out of sync
- [x] Update rng state initialization to take from correct device
- [x]  Tests
- [x] handling of retain_graph
- [x] respect fallback random

Fix for https://github.com/pytorch/pytorch/issues/130123.

Updates the aot_eager and cudagraph compilation of `run_and_save_rng_state` to use the new mechanism added by https://github.com/pytorch/pytorch/pull/114068 for CUDAGraph safe rng states.

We have a pair of rng states for the fwd and backward respectively. In both forward and backward, the rng op gets run with `graphsafe_run_with_rng_state`, which takes in an RNG state and hooks it onto the current RNG generator before running the operator. The rng states for fwd/backward are initialized with the same value. We ensure that for any given run of the forward, the corresponding backward run will have the same rng state for the op as was observed in the forward.

```
 ===== Forward graph 1 =====
 /data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", fwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = fwd_rng_state_0);  fwd_rng_state_0 = None
        ...

 ===== Backward graph 1 =====
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", tangents_1: "f32[4, 4][4, 1]cuda:0", bwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = bwd_rng_state_0);  bwd_rng_state_0 = None
```

There is some extra complication when a user either calls backward with retain_graph, or calls the backwards in a different order than they called the forwards. If a user has states fwd_rng_state0 and bwd_rng_state0 and calls:
- fwd0: fwd_rng_state0 -> fwd_rng_state1
- fwd1: fwd_rng_state1 -> fwd_rng_state2
- bwd1
- bwd0

Then naively, when bwd1 is invoked, the bwd rng states would not equal the states that were observed in fwd1. I added handling for this in the AOT runtime wrappers to detect pending backward invocations and the current position of the bwd rng states, and to update them when necessary. A minimal sketch of that bookkeeping is shown below.
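
The sketch below uses hypothetical names (`BwdRngStateTracker` and its methods are illustrative, not the actual runtime-wrapper API): snapshot the backward rng state at each forward call so a backward invoked out of order (or with retain_graph) can restore the state its matching forward observed.

```
import torch

class BwdRngStateTracker:
    def __init__(self, generator: torch.Generator):
        self.generator = generator
        self.snapshots: dict[int, torch.Tensor] = {}

    def on_forward(self, call_index: int) -> None:
        # Snapshot the bwd rng state as of this forward invocation.
        self.snapshots[call_index] = self.generator.get_state().clone()

    def before_backward(self, call_index: int) -> None:
        # Restore the state observed at the corresponding forward, covering
        # retain_graph and out-of-order backward invocations.
        self.generator.set_state(self.snapshots.pop(call_index))
```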

Other notes:

Because nodes which appear later in the forward appear earlier in the backward, we need a separate rng state for each operator. If we reused the same rng state across ops, the forward and backward would run those ops with different rng states, i.e., the states would not be applied in the same order.

Questions for reviewers:

This does change numerics, because the rng for the op is now taken from the input rng state instead of whatever the rng state would be midway through running the graph. Technically, we only need this for cudagraphs, but I'd prefer not to have an rng divergence just for cudagraphs. I am making it respect `fallback_random`.

Edit: decided to apply this to non-cudagraphs as well, as long as fallback_random is not set.

I'm initializing the rng states by cloning the current state. If you had something like 5 different rand calls in the model with the same shape, they'd all get the same value, which doesn't seem great. I could use some other initialization scheme, like deriving the seed from graph position. Let me know your thoughts.

Edit: updated to be taken from randint()

Update: initializing rng states from torch.randint().

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146878
Approved by: https://github.com/anijain2305, https://github.com/bdhirsh
2025-02-27 02:08:29 +00:00
IvanKobzarev
8594856651 [aotd] Alias of intermediate unwrap TensorAlias (#147638)
Bug was reported by internal user.

AOTD classifies outputs that are aliases of graph intermediates into different categories.

...
- output is an alias of an intermediate whose base is already an output
- output is an alias of an intermediate whose base is not an output

If we look at the fn:
```
def fn(x):
    ix = x + 1
    a = ix.transpose(0, 1)
    return a.detach(), a
```

output 0: a detach view of alias `a`, where `a` is already an output
output 1: an alias of intermediate `ix`, so an additional output `ix` will be added internally

output 0's base is TensorAlias(a) in this case, but it could also be a plain Tensor.
Adding runtime unwrapping solves this problem.

Alternatively, we could track the base of a.detach() all the way back to ix; in that case the base would always be a Tensor, not a TensorAlias.
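
A minimal sketch of the runtime-unwrapping idea, using a local stand-in for AOTD's TensorAlias wrapper (the dataclass and helper below are illustrative, not the actual AOTAutograd code):

```
from dataclasses import dataclass

import torch

@dataclass
class TensorAlias:  # stand-in for the wrapper AOTD uses internally
    alias: torch.Tensor

def unwrap_maybe_alias(base):
    # The recorded base of an aliased output may itself arrive wrapped in a
    # TensorAlias; unwrap it before regenerating the aliased output from it.
    return base.alias if isinstance(base, TensorAlias) else base
```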

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147638
Approved by: https://github.com/bdhirsh
2025-02-26 19:42:21 +00:00
Brian Hirsh
89b9c12de8 remove prints from partitioner (#147749)
See c57894cd74..22d8f9a657 (r1968015955)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147749
Approved by: https://github.com/Skylion007, https://github.com/laithsakka
2025-02-24 21:03:45 +00:00
Aaron Orenstein
db4ce78d46 PEP585: More UP006 fixes (#146392)
This should be the final PR before we can enable RUFF UP006.
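
For illustration, this is the kind of rewrite UP006 (PEP 585) enforces; the function below is a made-up example, not code from the PR:

```
# Before: from typing import Dict, List; def bucket(xs: List[int]) -> Dict[str, List[int]]: ...
# After: builtin generics replace the deprecated typing aliases.
def bucket(xs: list[int]) -> dict[str, list[int]]:
    out: dict[str, list[int]] = {}
    for x in xs:
        out.setdefault("even" if x % 2 == 0 else "odd", []).append(x)
    return out
```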

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146392
Approved by: https://github.com/justinchuby, https://github.com/albanD, https://github.com/Skylion007
2025-02-20 06:18:13 +00:00
Laith Sakka
454fbd5bbe realize stride symbols in estimate_runtime (#146752)
Unfortunately, I could not create a local repro or unit test.
Fixes https://github.com/pytorch/pytorch/issues/146686

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146752
Approved by: https://github.com/bobrenjc93, https://github.com/bdhirsh
2025-02-19 06:02:49 +00:00
Basil Wong
05001f0459 Add Structured Tracing for Traced Graph Edge Details for AC Debugging (#146634)
Summary:
Updating the structured trace infrastructure so that we are able to output to Zoomer and have an E2E solution.

Context Doc: https://docs.google.com/document/d/1T6omIBEWVhbOiwDLSLffgQwjxiT2rQv8QvvQwXkw4fY/edit?usp=sharing

Test Plan:
### Testing Structured Log + tlparse locally

Command:
```
TORCH_TRACE=/data/users/basilwong/fbsource/fbcode/log_torch_trace buck2 run mode/opt //aps_models/ads/icvr:icvr_launcher -- mode=local_fb_fm_v4 launcher.num_workers=2
```

Torch Trace Logs (local then sent to paste): P1686419449
```
cat log_torch_trace/dedicated_log_torch_trace_rank_0_2lg012xo.log | pastry
P1686419449
```

tlparse output: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpyiv5wj/rank_1/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=100

tlparse graph edge details output: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpyiv5wj/rank_1/9_0_0/joint_graph_information_397.txt?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=100

Differential Revision: D61557220

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146634
Approved by: https://github.com/jansel, https://github.com/Yuzhen11
2025-02-14 02:04:26 +00:00
Brian Hirsh
447a142de2 support input mutations on tangents in compile (#141131)
Fixes https://github.com/pytorch/pytorch/issues/141111. We previously supported mutations on saved activations that happened in the backward. This PR extends that support to tangents; an illustrative example follows.
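
A hedged illustration of the kind of program this enables (not taken from the PR's tests; the class and function names are made up): a compiled function whose custom backward mutates the incoming tangent in place.

```
import torch

class ScaleAndMutateGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        grad_out.mul_(2)  # in-place mutation of the tangent
        return grad_out

@torch.compile
def f(x):
    return ScaleAndMutateGrad.apply(x)

x = torch.randn(4, requires_grad=True)
f(x).sum().backward()  # .sum() stays outside the compiled region
```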

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141131
Approved by: https://github.com/zou3519
2025-02-13 17:48:56 +00:00
Oguz Ulgen
076215944a Turn on autograd local caches in fbcode (#146996)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146996
Approved by: https://github.com/jamesjwu
2025-02-12 23:04:39 +00:00
Simon Fan
387c993c3b [ca] remove private API: _compiled_autograd_should_lift (#146720)
Since the functional autograd + compiled autograd migration, we don't trace into nodes anymore and everything is lifted. We can't support this flag, which tried to inline nodes (make_fx style) during CA's initial pass. There is no remaining usage internally.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146720
Approved by: https://github.com/zou3519
2025-02-10 04:29:57 +00:00
rzou
15b1ac3e86 Add torch.func.debug_unwrap (#146528)
Use it to unwrap any functorch-wrapped tensor. I don't recommend using
the output in a program since it breaks the semantics of the transforms,
but it seems useful for debugging.

I will note that some people have wanted to get intermediate values out
of an e.g. grad transform, so this might be a way to do that...
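
A small usage sketch (the function `f` below is made up for illustration):

```
import torch
from torch.func import grad

def f(x):
    inner = torch.func.debug_unwrap(x)  # plain tensor underneath the functorch wrapper
    print(inner)  # fine for debugging; don't feed the unwrapped value back into the computation
    return (x ** 2).sum()

grad(f)(torch.randn(3))
```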

Test Plan:
- tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146528
Approved by: https://github.com/Chillee
2025-02-06 18:48:09 +00:00
Aaron Gokaslan
292af3cc89 [BE][Ez]: ISC001 Auto concatenate implicit one line strings (#146408)
Apply the ruff rule about implicit string concatenation; this autofixes string literals that are all the same type and on the same line. These lines were likely broken up by autoformatters in the past. All fixes are automated using the autofixes in ISC001.
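
An illustrative before/after for this autofix (made-up strings):

```
# Before (often the residue of an autoformatter's line-wrapping):
msg_before = "could not find " "the requested key"
# After ISC001's autofix:
msg_after = "could not find the requested key"
assert msg_before == msg_after
```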

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146408
Approved by: https://github.com/justinchuby, https://github.com/janeyx99
2025-02-04 19:07:04 +00:00
Sam Larsen
23fffb54d5 Use OrderedSet in _functorch/partitioners (#146102)
In an attempt to make partitioning more deterministic, change all sets in partitioners.py to OrderedSets. Note that this change does not fix the non-determinism we're seeing in the internal model, but let's at least eliminate this potential source of non-determinism before investigating any changes to the mincut approach.
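
A minimal sketch of why insertion order matters here; the helper below is illustrative only and emulates an ordered set with dict insertion order rather than using PyTorch's own OrderedSet:

```
def ordered_unique(items):
    # dict preserves insertion order, giving a deterministic "set" iteration
    return list(dict.fromkeys(items))

nodes = ["add_1", "mul_2", "relu_3", "add_1"]
print(ordered_unique(nodes))  # always ['add_1', 'mul_2', 'relu_3']
print(set(nodes))             # iteration order is unspecified and can vary across runs
```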
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146102
Approved by: https://github.com/oulgen
2025-02-04 17:43:07 +00:00
Animesh Jain
487400f47f [dynamo] Support functools.partial variables through inspect.signature (#146339)
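For context, this is the stdlib behavior dynamo needs to model when tracing (the function below is a made-up example):

```
import functools
import inspect

def conv(x, stride, padding=0):
    return (x, stride, padding)

p = functools.partial(conv, stride=2)
print(inspect.signature(p))  # (x, *, stride=2, padding=0)
```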
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146339
Approved by: https://github.com/jansel
ghstack dependencies: #146322, #146116
2025-02-04 04:39:39 +00:00
Tugsbayasgalan Manlaibaatar
041e08f9dc Add buffers to parameterizaiton rule (#145991)
Differential Revision: [D68959513](https://our.internmc.facebook.com/intern/diff/D68959513)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145991
Approved by: https://github.com/bdhirsh
2025-02-03 16:49:03 +00:00
Aaron Orenstein
ccbbc88bbb Turn on mypy for _dynamo/variables/builtin.py (#145552)
The fact that mypy errors were ignored was hiding several bugs in builtin.py (for example the previous diff's incorrect override and use of `call_getattr`)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145552
Approved by: https://github.com/anijain2305, https://github.com/Skylion007
ghstack dependencies: #145551
2025-01-30 22:21:32 +00:00
James Wu
d0aa1386b8 Disable AOTAutogradCache for triton version < 3.2 (#145937)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145937
Approved by: https://github.com/bdhirsh
2025-01-29 21:32:16 +00:00
Brian Hirsh
ed141d7d1a dont assign a size to _assert_scalar in partitioner (#143877)
Fixes https://github.com/pytorch/pytorch/issues/143876

Open to other suggestions - we have an invariant that all nodes in our ATen graphs should have a `meta['val']` field, but I don't think this is actually true in all cases, so I special-cased the invariant to ignore `_assert_scalar()` (which is a "special" op used in dynamic shapes for runtime asserts, and doesn't have a meta['val'] field). A sketch of the idea is shown below.
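
A hedged sketch of the special-casing (illustrative names; not the partitioner's actual helper):

```
import torch

def node_storage_size(node) -> int:
    # Runtime-assert nodes legitimately carry no meta['val'] and produce no
    # tensor value, so give them zero size instead of asserting.
    if node.target is torch.ops.aten._assert_scalar.default:
        return 0
    val = node.meta["val"]
    return val.numel() * val.element_size()
```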

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143877
Approved by: https://github.com/zou3519
2025-01-29 16:21:37 +00:00
Brian Hirsh
7ca156f0ee partitioner: avoid inserting duplicates into heap (#145082)
Fixes https://github.com/pytorch/pytorch/issues/145081

This looks like it was a source of quadratic compile times in the torchtitan CP graphs. There's some code in the partitioner that iteratively adds users of a node to a heap, and pops the earliest user. If you have long parallel chains of fusible ops that all eventually feed into some shared ops, then this can result in:
(1) a node getting added to the heap many times
(2) each time we pop that node, we add (duplicates of) each of that node users to the heap
(3) repeat with each user (see the dedup sketch below)
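
A minimal sketch of the fix under illustrative names (not the partitioner's real code): remember what has already been pushed so each node enters the heap, and is popped, at most once.

```
import heapq

def visit_users_in_order(start_node, users_of, order):
    # `order` maps each node to a unique integer position in the graph.
    heap, pushed = [], set()

    def push(node):
        if node not in pushed:
            pushed.add(node)
            heapq.heappush(heap, (order[node], node))

    for user in users_of(start_node):
        push(user)
    while heap:
        _, node = heapq.heappop(heap)
        yield node
        for user in users_of(node):
            push(user)
```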

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145082
Approved by: https://github.com/xmfan
2025-01-28 23:44:45 +00:00
James Wu
d9ffa5da65 Log info for AOTAutogradCache bypasses instead of warning (#145768)
Fixes #145767

FxGraphCache also logs at info level instead of warning, so let's do that here too.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145768
Approved by: https://github.com/eellison, https://github.com/bdhirsh
2025-01-28 19:25:36 +00:00
James Wu
7c1fc0a047 Log cache state for AOTAutograd in title of file (#145715)
Differential Revision: [D68692755](https://our.internmc.facebook.com/intern/diff/D68692755/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145715
Approved by: https://github.com/bobrenjc93
2025-01-28 02:14:18 +00:00
rzou
ea141d8134 functional compiled autograd (#144707)
This PR squashes together the following commits:

https://github.com/pytorch/pytorch/pull/144115
https://github.com/pytorch/pytorch/pull/143417
https://github.com/pytorch/pytorch/pull/143405
https://github.com/pytorch/pytorch/pull/143387
https://github.com/pytorch/pytorch/pull/143304
https://github.com/pytorch/pytorch/pull/143296

This is a refactor of compiled autograd to use "functional autograd". The end goal is for compiled autograd's initial capture to stop specializing on Tensor metadata, thereby allowing compiled autograd to better handle Tensor subclasses.

For more information, please read the commit messages for each PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144707
Approved by: https://github.com/bdhirsh, https://github.com/xmfan, https://github.com/jansel
2025-01-27 05:20:56 +00:00
Edward Z. Yang
90448f0128 Output of nonzero is transposed, fix fake tensor (#144695)
Needs this companion executorch PR: https://github.com/pytorch/executorch/pull/7657

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144695
Approved by: https://github.com/bobrenjc93, https://github.com/albanD
2025-01-26 01:07:22 +00:00
PyTorch MergeBot
16c4f8c395 Revert "[compiled autograd] Always proxy autograd.Function nodes; handle AOT backwards (#143405)"
This reverts commit ec820fe57c.

Reverted https://github.com/pytorch/pytorch/pull/143405 on behalf of https://github.com/izaitsevfb due to breaking internal tests T213390054 ([comment](https://github.com/pytorch/pytorch/pull/143296#issuecomment-2611224926))
2025-01-23 23:34:13 +00:00
PyTorch MergeBot
ab082863a1 Revert "[compiled autograd] support Tensor Subclasses in AOTBackward (#144115)"
This reverts commit 082c28c3c6.

Reverted https://github.com/pytorch/pytorch/pull/144115 on behalf of https://github.com/izaitsevfb due to breaking internal tests T213390054 ([comment](https://github.com/pytorch/pytorch/pull/143296#issuecomment-2611224926))
2025-01-23 23:34:12 +00:00
Aaron Gokaslan
5ebca3015d [BE]: Simplify set add with set update (#145152)
Simplifies the set update slightly to be more readable and efficient.
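An illustrative before/after for this kind of cleanup (made-up values):

```
# Before:
seen = set()
for item in ("a", "b", "c"):
    seen.add(item)

# After: one bulk update instead of per-element adds.
seen_updated = set()
seen_updated.update(("a", "b", "c"))
assert seen == seen_updated
```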
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145152
Approved by: https://github.com/XuehaiPan, https://github.com/albanD

Co-authored-by: Xuehai Pan <XuehaiPan@outlook.com>
2025-01-23 20:18:13 +00:00