Commit Graph

153 Commits

Author SHA1 Message Date
Brian Hirsh
7f836b747f partitioner: ensure collectives saved by SAC that are actually unused in the bw are properly not saved (#149652)
This PR fixes one of the issues described here: https://github.com/pytorch/torchtitan/issues/866#issuecomment-2726015248

I spent some time trying to write a unit test and ultimately failed. If folks are interested I can spend more time trying to, but otherwise I have an E2E test with torchtitan. command:
```
CUDA_VISIBLE_DEVICES=1,2,3,4 NGPU=4 CONFIG_FILE="./torchtitan/models/llama/train_configs/llama3_8b.toml" tlp ./run_train.sh --training.steps=30  --training.tensor_parallel_degree=2 --training.compile --experimental.enable_async_tensor_parallel
```

here's the backward graph generated prior to the PR: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/hirsheybar/f7d17388-42c2-4d7e-8a55-a00387341ecb/custom/rank_0/-_0_0_0/aot_backward_graph_9.txt?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000

and new backward graph with the PR: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/hirsheybar/ab8576fc-98c1-4915-af47-699aa8e2557e/custom/rank_0/-_0_0_0/aot_backward_graph_9.txt?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000

The main difference is that the input arg `reduce_scatter_tensor_1` is dead code in the bw graph, causing us to unnecessarily save a giant `reduce_scatter` for bw. With the PR, we properly ensure that it is not saved for backward.

More comments in the PR, but the main thing going on is that:

(1) We have some existing logic that checks for activations that are actually dead code in the backward, and removes them

(2) collectives are not properly handled by this code. Why? collective are **always** followed by  `wait_tensor()` call. So we need to go one node further and check if the "dead" code has a wait_tensor user that is also dead

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149652
Approved by: https://github.com/zou3519
ghstack dependencies: #149514
2025-03-21 22:09:19 +00:00
Brian Hirsh
f06e366532 partitioner: treat inputs with static indices as free to save (#148922)
Fixes https://github.com/pytorch/pytorch/issues/141881

internal xref: https://fb.workplace.com/groups/1075192433118967/posts/1538435030128036/?comment_id=1556782068293332

I tried to make a test case out of the code linked in that github issue. The setup + bad outcome today was as follows:

(1) you have a graph where one of its inputs is a model weight

(2) in the backward, you do some downstream compute on `weight`, `tmp = f(weight)`, where (a) `tmp` is of a smaller size than `weight`, and (b) the compute is trivially fusible into other kernels (so the partitioner thinks it is "free" to recompute

(3) since `sizeof(tmp) < sizeof(weight)` and the recompute is free, the partitioner decides that it would be strictly better to save `tmp` for backward instead of weight

(4) this is bad: `weight` is a static tensor that sits in GPU memory for the duration of your entire training loop, so saving it for backward has no negative impact on peak memory.  Since we're saving `tmp` instead, we end up unnecessarily increasing peak memory. In particular - the repro involves an autograd.Function in eager that saves the weight for bw, so we end up hitting higher peak memory in compile

The fix I'm trying out in this PR is to tell the partitioner that graph inputs that we know have static addresses (aka parameters) are "free" to save.

Below is the fw/bw graph before my change, where you can see that instead of `primals_2` being saved for backward, we save `t_8` (which involves some low precision downstream compute on `primals_2`, that is only needed in the backward.

```
 ===== Forward graph 0 =====
 /data/users/hirsheybar/checkout2/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "bf16[64, 64][64, 1]cuda:0", primals_2: "bf16[64, 64][64, 1]cuda:0", primals_3: "bf16[64][1]cuda:0"):
         # File: /data/users/hirsheybar/checkout2/pytorch/test/dynamo/test_repros.py:6943 in forward, code: out = Fp8LinearFn.apply(
        abs_1: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.abs.default(primals_1)
        view: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(abs_1, [64, 1, 64]);  abs_1 = None
        amax: "bf16[64, 1][1, 1]cuda:0" = torch.ops.aten.amax.default(view, [-1]);  view = None
        abs_2: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.abs.default(primals_2)
        view_1: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(abs_2, [64, 1, 64]);  abs_2 = None
        amax_1: "bf16[64, 1][1, 1]cuda:0" = torch.ops.aten.amax.default(view_1, [-1]);  view_1 = None
        _to_copy: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten._to_copy.default(amax, dtype = torch.float32);  amax = None
        clamp: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.clamp.default(_to_copy, 1e-12);  _to_copy = None
        div: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.div.Tensor(clamp, 448.0);  clamp = None
        reciprocal: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.reciprocal.default(div)
        view_2: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(primals_1, [64, 1, 64])
        view_3: "bf16[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_2, [64, 1, 1, 64]);  view_2 = None
        slice_1: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(reciprocal, 0, 0, 9223372036854775807);  reciprocal = None
        unsqueeze: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_1, 1);  slice_1 = None
        slice_2: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze, 2, 0, 9223372036854775807);  unsqueeze = None
        unsqueeze_1: "f32[64, 1, 1, 1][1, 1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_2, 3);  slice_2 = None
        mul: "f32[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_3, unsqueeze_1);  view_3 = unsqueeze_1 = None
        view_4: "f32[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(mul, [64, 1, 64]);  mul = None
        view_5: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_4, [64, 64]);  view_4 = None
        _to_copy_1: "f8e4m3fn[64, 64][64, 1]cuda:0" = torch.ops.aten._to_copy.default(view_5, dtype = torch.float8_e4m3fn);  view_5 = None
        _to_copy_2: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten._to_copy.default(amax_1, dtype = torch.float32)
        clamp_1: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.clamp.default(_to_copy_2, 1e-12);  _to_copy_2 = None
        div_1: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.div.Tensor(clamp_1, 448.0);  clamp_1 = None
        reciprocal_1: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.reciprocal.default(div_1)
        view_6: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(primals_2, [64, 1, 64])
        view_7: "bf16[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_6, [64, 1, 1, 64]);  view_6 = None
        slice_3: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(reciprocal_1, 0, 0, 9223372036854775807);  reciprocal_1 = None
        unsqueeze_2: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_3, 1);  slice_3 = None
        slice_4: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_2, 2, 0, 9223372036854775807);  unsqueeze_2 = None
        unsqueeze_3: "f32[64, 1, 1, 1][1, 1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_4, 3);  slice_4 = None
        mul_1: "f32[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_7, unsqueeze_3);  view_7 = unsqueeze_3 = None
        view_8: "f32[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(mul_1, [64, 1, 64]);  mul_1 = None
        view_9: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_8, [64, 64]);  view_8 = None
        _to_copy_3: "f8e4m3fn[64, 64][64, 1]cuda:0" = torch.ops.aten._to_copy.default(view_9, dtype = torch.float8_e4m3fn);  view_9 = None
        t: "f32[1, 64][1, 1]cuda:0" = torch.ops.aten.t.default(div_1);  div_1 = None
        new_ones: "f32[1, 1][1, 1]cuda:0" = torch.ops.aten.new_ones.default(div, [1, 1], pin_memory = False)
        new_ones_1: "f32[1, 1][1, 1]cuda:0" = torch.ops.aten.new_ones.default(t, [1, 1], pin_memory = False)
        t_2: "f8e4m3fn[64, 64][1, 64]cuda:0" = torch.ops.aten.t.default(_to_copy_3);  _to_copy_3 = None
        t_3: "f32[1, 1][1, 1]cuda:0" = torch.ops.aten.t.default(new_ones_1);  new_ones_1 = None
        _scaled_mm: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten._scaled_mm.default(_to_copy_1, t_2, new_ones, t_3, None, None, torch.bfloat16);  _to_copy_1 = t_2 = new_ones = t_3 = None
        view_10: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(_scaled_mm, [64, 1, 64]);  _scaled_mm = None
        view_11: "bf16[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_10, [64, 1, 1, 64]);  view_10 = None
        slice_5: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(div, 0, 0, 9223372036854775807);  div = None
        unsqueeze_4: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_5, 1);  slice_5 = None
        slice_6: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_4, 2, 0, 9223372036854775807);  unsqueeze_4 = None
        unsqueeze_5: "f32[64, 1, 1, 1][1, 1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_6, 3);  slice_6 = None
        mul_2: "f32[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_11, unsqueeze_5);  view_11 = unsqueeze_5 = None
        view_12: "f32[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(mul_2, [64, 1, 64]);  mul_2 = None
        view_13: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_12, [64, 64]);  view_12 = None
        view_14: "f32[1, 64, 64][4096, 64, 1]cuda:0" = torch.ops.aten.view.default(view_13, [1, 64, 64]);  view_13 = None
        view_15: "f32[1, 64, 64, 1][4096, 64, 1, 1]cuda:0" = torch.ops.aten.view.default(view_14, [1, 64, 64, 1]);  view_14 = None
        slice_7: "f32[1, 64][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(t, 0, 0, 9223372036854775807);  t = None
        unsqueeze_6: "f32[1, 1, 64][1, 64, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_7, 1);  slice_7 = None
        slice_8: "f32[1, 1, 64][1, 64, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_6, 2, 0, 9223372036854775807);  unsqueeze_6 = None
        unsqueeze_7: "f32[1, 1, 64, 1][1, 64, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_8, 3);  slice_8 = None
        mul_3: "f32[1, 64, 64, 1][4096, 64, 1, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_15, unsqueeze_7);  view_15 = unsqueeze_7 = None
        view_16: "f32[64, 64, 1][64, 1, 1]cuda:0" = torch.ops.aten.view.default(mul_3, [64, 64, 1]);  mul_3 = None
        view_17: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_16, [64, 64]);  view_16 = None
        _to_copy_4: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten._to_copy.default(view_17, dtype = torch.bfloat16);  view_17 = None
        add: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.add.Tensor(_to_copy_4, primals_3);  _to_copy_4 = primals_3 = None
        t_4: "bf16[64, 64][1, 64]cuda:0" = torch.ops.aten.t.default(primals_2);  primals_2 = None
        clone: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.clone.default(t_4, memory_format = torch.contiguous_format);  t_4 = None
        t_5: "bf16[1, 64][1, 1]cuda:0" = torch.ops.aten.t.default(amax_1);  amax_1 = None
        view_21: "bf16[1, 1, 64][1, 64, 1]cuda:0" = torch.ops.aten.view.default(t_5, [1, 1, 64]);  t_5 = None
        amax_3: "bf16[1, 1][1, 1]cuda:0" = torch.ops.aten.amax.default(view_21, [-1]);  view_21 = None
        unsqueeze_8: "bf16[1, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(amax_3, 1);  amax_3 = None
        expand: "bf16[1, 64, 1][1, 0, 1]cuda:0" = torch.ops.aten.expand.default(unsqueeze_8, [1, 64, 1])
        clone_1: "bf16[1, 64, 1][64, 1, 1]cuda:0" = torch.ops.aten.clone.default(expand, memory_format = torch.contiguous_format);  expand = None
        view_22: "bf16[64, 1][1, 1]cuda:0" = torch.ops.aten.view.default(clone_1, [64, 1]);  clone_1 = None
        _to_copy_7: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten._to_copy.default(view_22, dtype = torch.float32);  view_22 = None
        clamp_3: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.clamp.default(_to_copy_7, 1e-12);  _to_copy_7 = None
        div_3: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.div.Tensor(clamp_3, 448.0);  clamp_3 = None
        reciprocal_3: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.reciprocal.default(div_3);  div_3 = None
        view_27: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(clone, [64, 1, 64]);  clone = None
        view_28: "bf16[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_27, [64, 1, 1, 64]);  view_27 = None
        slice_11: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(reciprocal_3, 0, 0, 9223372036854775807);  reciprocal_3 = None
        unsqueeze_11: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_11, 1);  slice_11 = None
        slice_12: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_11, 2, 0, 9223372036854775807);  unsqueeze_11 = None
        unsqueeze_12: "f32[64, 1, 1, 1][1, 1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_12, 3);  slice_12 = None
        mul_5: "f32[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_28, unsqueeze_12);  view_28 = unsqueeze_12 = None
        view_29: "f32[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(mul_5, [64, 1, 64]);  mul_5 = None
        view_30: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_29, [64, 64]);  view_29 = None
        _to_copy_8: "f8e4m3fn[64, 64][64, 1]cuda:0" = torch.ops.aten._to_copy.default(view_30, dtype = torch.float8_e4m3fn);  view_30 = None
        t_8: "f8e4m3fn[64, 64][1, 64]cuda:0" = torch.ops.aten.t.default(_to_copy_8);  _to_copy_8 = None

        # No stacktrace found for following nodes
        view_39: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(add, [64, 64]);  add = None
        return (view_39, primals_1, unsqueeze_8, t_8)

INFO: TRACED GRAPH
 ===== Backward graph 0 =====
 <eval_with_key>.1 class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "bf16[64, 64][64, 1]cuda:0", unsqueeze_8: "bf16[1, 1, 1][1, 1, 1]cuda:0", t_8: "f8e4m3fn[64, 64][1, 64]cuda:0", tangents_1: "bf16[64, 64][64, 1]cuda:0"):
         # File: /data/users/hirsheybar/checkout2/pytorch/test/dynamo/test_repros.py:6946 in forward, code: out = out.unflatten(0, input.shape[:-1])
        view_19: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(tangents_1, [64, 64]);  tangents_1 = None

         # File: /data/users/hirsheybar/checkout2/pytorch/test/dynamo/test_repros.py:6943 in forward, code: out = Fp8LinearFn.apply(
        abs_3: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.abs.default(view_19)
        view_20: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(abs_3, [64, 1, 64]);  abs_3 = None
        amax_2: "bf16[64, 1][1, 1]cuda:0" = torch.ops.aten.amax.default(view_20, [-1]);  view_20 = None
        expand: "bf16[1, 64, 1][1, 0, 1]cuda:0" = torch.ops.aten.expand.default(unsqueeze_8, [1, 64, 1]);  unsqueeze_8 = None
        clone_1: "bf16[1, 64, 1][64, 1, 1]cuda:0" = torch.ops.aten.clone.default(expand, memory_format = torch.contiguous_format);  expand = None
        view_22: "bf16[64, 1][1, 1]cuda:0" = torch.ops.aten.view.default(clone_1, [64, 1]);  clone_1 = None
        _to_copy_5: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten._to_copy.default(amax_2, dtype = torch.float32);  amax_2 = None
        clamp_2: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.clamp.default(_to_copy_5, 1e-12);  _to_copy_5 = None
        div_2: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.div.Tensor(clamp_2, 448.0);  clamp_2 = None
        reciprocal_2: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.reciprocal.default(div_2)
        view_23: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_19, [64, 1, 64])
        view_24: "bf16[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_23, [64, 1, 1, 64]);  view_23 = None
        slice_9: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(reciprocal_2, 0, 0, 9223372036854775807);  reciprocal_2 = None
        unsqueeze_9: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_9, 1);  slice_9 = None
        slice_10: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_9, 2, 0, 9223372036854775807);  unsqueeze_9 = None
        unsqueeze_10: "f32[64, 1, 1, 1][1, 1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_10, 3);  slice_10 = None
        mul_4: "f32[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_24, unsqueeze_10);  view_24 = unsqueeze_10 = None
        view_25: "f32[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(mul_4, [64, 1, 64]);  mul_4 = None
        view_26: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_25, [64, 64]);  view_25 = None
        _to_copy_6: "f8e4m3fn[64, 64][64, 1]cuda:0" = torch.ops.aten._to_copy.default(view_26, dtype = torch.float8_e4m3fn);  view_26 = None
        _to_copy_7: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten._to_copy.default(view_22, dtype = torch.float32);  view_22 = None
        clamp_3: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.clamp.default(_to_copy_7, 1e-12);  _to_copy_7 = None
        div_3: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.div.Tensor(clamp_3, 448.0);  clamp_3 = None
        t_6: "f32[1, 64][1, 1]cuda:0" = torch.ops.aten.t.default(div_3);  div_3 = None
        new_ones_2: "f32[1, 1][1, 1]cuda:0" = torch.ops.aten.new_ones.default(div_2, [1, 1], pin_memory = False)
        new_ones_3: "f32[1, 1][1, 1]cuda:0" = torch.ops.aten.new_ones.default(t_6, [1, 1], pin_memory = False)
        t_9: "f32[1, 1][1, 1]cuda:0" = torch.ops.aten.t.default(new_ones_3);  new_ones_3 = None
        _scaled_mm_1: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten._scaled_mm.default(_to_copy_6, t_8, new_ones_2, t_9, None, None, torch.bfloat16);  _to_copy_6 = t_8 = new_ones_2 = t_9 = None
        view_31: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(_scaled_mm_1, [64, 1, 64]);  _scaled_mm_1 = None
        view_32: "bf16[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_31, [64, 1, 1, 64]);  view_31 = None
        slice_13: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(div_2, 0, 0, 9223372036854775807);  div_2 = None
        unsqueeze_13: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_13, 1);  slice_13 = None
        slice_14: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_13, 2, 0, 9223372036854775807);  unsqueeze_13 = None
        unsqueeze_14: "f32[64, 1, 1, 1][1, 1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_14, 3);  slice_14 = None
        mul_6: "f32[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_32, unsqueeze_14);  view_32 = unsqueeze_14 = None
        view_33: "f32[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(mul_6, [64, 1, 64]);  mul_6 = None
        view_34: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_33, [64, 64]);  view_33 = None
        view_35: "f32[1, 64, 64][4096, 64, 1]cuda:0" = torch.ops.aten.view.default(view_34, [1, 64, 64]);  view_34 = None
        view_36: "f32[1, 64, 64, 1][4096, 64, 1, 1]cuda:0" = torch.ops.aten.view.default(view_35, [1, 64, 64, 1]);  view_35 = None
        slice_15: "f32[1, 64][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(t_6, 0, 0, 9223372036854775807);  t_6 = None
        unsqueeze_15: "f32[1, 1, 64][1, 64, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_15, 1);  slice_15 = None
        slice_16: "f32[1, 1, 64][1, 64, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_15, 2, 0, 9223372036854775807);  unsqueeze_15 = None
        unsqueeze_16: "f32[1, 1, 64, 1][1, 64, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_16, 3);  slice_16 = None
        mul_7: "f32[1, 64, 64, 1][4096, 64, 1, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_36, unsqueeze_16);  view_36 = unsqueeze_16 = None
        view_37: "f32[64, 64, 1][64, 1, 1]cuda:0" = torch.ops.aten.view.default(mul_7, [64, 64, 1]);  mul_7 = None
        view_38: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_37, [64, 64]);  view_37 = None
        _to_copy_9: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten._to_copy.default(view_38, dtype = torch.bfloat16);  view_38 = None
        t_10: "bf16[64, 64][1, 64]cuda:0" = torch.ops.aten.t.default(view_19)
        mm: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.mm.default(t_10, primals_1);  t_10 = primals_1 = None
        sum_1: "bf16[64][1]cuda:0" = torch.ops.aten.sum.dim_IntList(view_19, [0]);  view_19 = None
        return (_to_copy_9, mm, sum_1)

```

With the change, we save primals_2 for backward instead

```
 ===== Forward graph 0 =====
 /data/users/hirsheybar/checkout2/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "bf16[64, 64][64, 1]cuda:0", primals_2: "bf16[64, 64][64, 1]cuda:0", primals_3: "bf16[64][1]cuda:0"):
         # File: /data/users/hirsheybar/checkout2/pytorch/test/dynamo/test_repros.py:6943 in forward, code: out = Fp8LinearFn.apply(
        abs_1: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.abs.default(primals_1)
        view: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(abs_1, [64, 1, 64]);  abs_1 = None
        amax: "bf16[64, 1][1, 1]cuda:0" = torch.ops.aten.amax.default(view, [-1]);  view = None
        abs_2: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.abs.default(primals_2)
        view_1: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(abs_2, [64, 1, 64]);  abs_2 = None
        amax_1: "bf16[64, 1][1, 1]cuda:0" = torch.ops.aten.amax.default(view_1, [-1]);  view_1 = None
        _to_copy: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten._to_copy.default(amax, dtype = torch.float32);  amax = None
        clamp: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.clamp.default(_to_copy, 1e-12);  _to_copy = None
        div: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.div.Tensor(clamp, 448.0);  clamp = None
        reciprocal: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.reciprocal.default(div)
        view_2: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(primals_1, [64, 1, 64])
        view_3: "bf16[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_2, [64, 1, 1, 64]);  view_2 = None
        slice_1: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(reciprocal, 0, 0, 9223372036854775807);  reciprocal = None
        unsqueeze: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_1, 1);  slice_1 = None
        slice_2: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze, 2, 0, 9223372036854775807);  unsqueeze = None
        unsqueeze_1: "f32[64, 1, 1, 1][1, 1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_2, 3);  slice_2 = None
        mul: "f32[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_3, unsqueeze_1);  view_3 = unsqueeze_1 = None
        view_4: "f32[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(mul, [64, 1, 64]);  mul = None
        view_5: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_4, [64, 64]);  view_4 = None
        _to_copy_1: "f8e4m3fn[64, 64][64, 1]cuda:0" = torch.ops.aten._to_copy.default(view_5, dtype = torch.float8_e4m3fn);  view_5 = None
        _to_copy_2: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten._to_copy.default(amax_1, dtype = torch.float32)
        clamp_1: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.clamp.default(_to_copy_2, 1e-12);  _to_copy_2 = None
        div_1: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.div.Tensor(clamp_1, 448.0);  clamp_1 = None
        reciprocal_1: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.reciprocal.default(div_1)
        view_6: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(primals_2, [64, 1, 64])
        view_7: "bf16[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_6, [64, 1, 1, 64]);  view_6 = None
        slice_3: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(reciprocal_1, 0, 0, 9223372036854775807);  reciprocal_1 = None
        unsqueeze_2: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_3, 1);  slice_3 = None
        slice_4: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_2, 2, 0, 9223372036854775807);  unsqueeze_2 = None
        unsqueeze_3: "f32[64, 1, 1, 1][1, 1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_4, 3);  slice_4 = None
        mul_1: "f32[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_7, unsqueeze_3);  view_7 = unsqueeze_3 = None
        view_8: "f32[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(mul_1, [64, 1, 64]);  mul_1 = None
        view_9: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_8, [64, 64]);  view_8 = None
        _to_copy_3: "f8e4m3fn[64, 64][64, 1]cuda:0" = torch.ops.aten._to_copy.default(view_9, dtype = torch.float8_e4m3fn);  view_9 = None
        t: "f32[1, 64][1, 1]cuda:0" = torch.ops.aten.t.default(div_1);  div_1 = None
        new_ones: "f32[1, 1][1, 1]cuda:0" = torch.ops.aten.new_ones.default(div, [1, 1], pin_memory = False)
        new_ones_1: "f32[1, 1][1, 1]cuda:0" = torch.ops.aten.new_ones.default(t, [1, 1], pin_memory = False)
        t_2: "f8e4m3fn[64, 64][1, 64]cuda:0" = torch.ops.aten.t.default(_to_copy_3);  _to_copy_3 = None
        t_3: "f32[1, 1][1, 1]cuda:0" = torch.ops.aten.t.default(new_ones_1);  new_ones_1 = None
        _scaled_mm: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten._scaled_mm.default(_to_copy_1, t_2, new_ones, t_3, None, None, torch.bfloat16);  _to_copy_1 = t_2 = new_ones = t_3 = None
        view_10: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(_scaled_mm, [64, 1, 64]);  _scaled_mm = None
        view_11: "bf16[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_10, [64, 1, 1, 64]);  view_10 = None
        slice_5: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(div, 0, 0, 9223372036854775807);  div = None
        unsqueeze_4: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_5, 1);  slice_5 = None
        slice_6: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_4, 2, 0, 9223372036854775807);  unsqueeze_4 = None
        unsqueeze_5: "f32[64, 1, 1, 1][1, 1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_6, 3);  slice_6 = None
        mul_2: "f32[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_11, unsqueeze_5);  view_11 = unsqueeze_5 = None
        view_12: "f32[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(mul_2, [64, 1, 64]);  mul_2 = None
        view_13: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_12, [64, 64]);  view_12 = None
        view_14: "f32[1, 64, 64][4096, 64, 1]cuda:0" = torch.ops.aten.view.default(view_13, [1, 64, 64]);  view_13 = None
        view_15: "f32[1, 64, 64, 1][4096, 64, 1, 1]cuda:0" = torch.ops.aten.view.default(view_14, [1, 64, 64, 1]);  view_14 = None
        slice_7: "f32[1, 64][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(t, 0, 0, 9223372036854775807);  t = None
        unsqueeze_6: "f32[1, 1, 64][1, 64, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_7, 1);  slice_7 = None
        slice_8: "f32[1, 1, 64][1, 64, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_6, 2, 0, 9223372036854775807);  unsqueeze_6 = None
        unsqueeze_7: "f32[1, 1, 64, 1][1, 64, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_8, 3);  slice_8 = None
        mul_3: "f32[1, 64, 64, 1][4096, 64, 1, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_15, unsqueeze_7);  view_15 = unsqueeze_7 = None
        view_16: "f32[64, 64, 1][64, 1, 1]cuda:0" = torch.ops.aten.view.default(mul_3, [64, 64, 1]);  mul_3 = None
        view_17: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_16, [64, 64]);  view_16 = None
        _to_copy_4: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten._to_copy.default(view_17, dtype = torch.bfloat16);  view_17 = None
        add: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.add.Tensor(_to_copy_4, primals_3);  _to_copy_4 = primals_3 = None
        t_5: "bf16[1, 64][1, 1]cuda:0" = torch.ops.aten.t.default(amax_1);  amax_1 = None
        view_21: "bf16[1, 1, 64][1, 64, 1]cuda:0" = torch.ops.aten.view.default(t_5, [1, 1, 64]);  t_5 = None
        amax_3: "bf16[1, 1][1, 1]cuda:0" = torch.ops.aten.amax.default(view_21, [-1]);  view_21 = None
        unsqueeze_8: "bf16[1, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(amax_3, 1);  amax_3 = None

        # No stacktrace found for following nodes
        view_39: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(add, [64, 64]);  add = None
        return (view_39, primals_1, primals_2, unsqueeze_8)

INFO: TRACED GRAPH
 ===== Backward graph 0 =====
 <eval_with_key>.1 class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "bf16[64, 64][64, 1]cuda:0", primals_2: "bf16[64, 64][64, 1]cuda:0", unsqueeze_8: "bf16[1, 1, 1][1, 1, 1]cuda:0", tangents_1: "bf16[64, 64][64, 1]cuda:0"):
         # File: /data/users/hirsheybar/checkout2/pytorch/test/dynamo/test_repros.py:6946 in forward, code: out = out.unflatten(0, input.shape[:-1])
        view_19: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(tangents_1, [64, 64]);  tangents_1 = None

         # File: /data/users/hirsheybar/checkout2/pytorch/test/dynamo/test_repros.py:6943 in forward, code: out = Fp8LinearFn.apply(
        t_4: "bf16[64, 64][1, 64]cuda:0" = torch.ops.aten.t.default(primals_2);  primals_2 = None
        clone: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.clone.default(t_4, memory_format = torch.contiguous_format);  t_4 = None
        abs_3: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.abs.default(view_19)
        view_20: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(abs_3, [64, 1, 64]);  abs_3 = None
        amax_2: "bf16[64, 1][1, 1]cuda:0" = torch.ops.aten.amax.default(view_20, [-1]);  view_20 = None
        expand: "bf16[1, 64, 1][1, 0, 1]cuda:0" = torch.ops.aten.expand.default(unsqueeze_8, [1, 64, 1]);  unsqueeze_8 = None
        clone_1: "bf16[1, 64, 1][64, 1, 1]cuda:0" = torch.ops.aten.clone.default(expand, memory_format = torch.contiguous_format);  expand = None
        view_22: "bf16[64, 1][1, 1]cuda:0" = torch.ops.aten.view.default(clone_1, [64, 1]);  clone_1 = None
        _to_copy_5: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten._to_copy.default(amax_2, dtype = torch.float32);  amax_2 = None
        clamp_2: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.clamp.default(_to_copy_5, 1e-12);  _to_copy_5 = None
        div_2: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.div.Tensor(clamp_2, 448.0);  clamp_2 = None
        reciprocal_2: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.reciprocal.default(div_2)
        view_23: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_19, [64, 1, 64])
        view_24: "bf16[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_23, [64, 1, 1, 64]);  view_23 = None
        slice_9: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(reciprocal_2, 0, 0, 9223372036854775807);  reciprocal_2 = None
        unsqueeze_9: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_9, 1);  slice_9 = None
        slice_10: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_9, 2, 0, 9223372036854775807);  unsqueeze_9 = None
        unsqueeze_10: "f32[64, 1, 1, 1][1, 1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_10, 3);  slice_10 = None
        mul_4: "f32[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_24, unsqueeze_10);  view_24 = unsqueeze_10 = None
        view_25: "f32[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(mul_4, [64, 1, 64]);  mul_4 = None
        view_26: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_25, [64, 64]);  view_25 = None
        _to_copy_6: "f8e4m3fn[64, 64][64, 1]cuda:0" = torch.ops.aten._to_copy.default(view_26, dtype = torch.float8_e4m3fn);  view_26 = None
        _to_copy_7: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten._to_copy.default(view_22, dtype = torch.float32);  view_22 = None
        clamp_3: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.clamp.default(_to_copy_7, 1e-12);  _to_copy_7 = None
        div_3: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.div.Tensor(clamp_3, 448.0);  clamp_3 = None
        reciprocal_3: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.reciprocal.default(div_3)
        view_27: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(clone, [64, 1, 64]);  clone = None
        view_28: "bf16[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_27, [64, 1, 1, 64]);  view_27 = None
        slice_11: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(reciprocal_3, 0, 0, 9223372036854775807);  reciprocal_3 = None
        unsqueeze_11: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_11, 1);  slice_11 = None
        slice_12: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_11, 2, 0, 9223372036854775807);  unsqueeze_11 = None
        unsqueeze_12: "f32[64, 1, 1, 1][1, 1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_12, 3);  slice_12 = None
        mul_5: "f32[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_28, unsqueeze_12);  view_28 = unsqueeze_12 = None
        view_29: "f32[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(mul_5, [64, 1, 64]);  mul_5 = None
        view_30: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_29, [64, 64]);  view_29 = None
        _to_copy_8: "f8e4m3fn[64, 64][64, 1]cuda:0" = torch.ops.aten._to_copy.default(view_30, dtype = torch.float8_e4m3fn);  view_30 = None
        t_6: "f32[1, 64][1, 1]cuda:0" = torch.ops.aten.t.default(div_3);  div_3 = None
        new_ones_2: "f32[1, 1][1, 1]cuda:0" = torch.ops.aten.new_ones.default(div_2, [1, 1], pin_memory = False)
        new_ones_3: "f32[1, 1][1, 1]cuda:0" = torch.ops.aten.new_ones.default(t_6, [1, 1], pin_memory = False)
        t_8: "f8e4m3fn[64, 64][1, 64]cuda:0" = torch.ops.aten.t.default(_to_copy_8);  _to_copy_8 = None
        t_9: "f32[1, 1][1, 1]cuda:0" = torch.ops.aten.t.default(new_ones_3);  new_ones_3 = None
        _scaled_mm_1: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten._scaled_mm.default(_to_copy_6, t_8, new_ones_2, t_9, None, None, torch.bfloat16);  _to_copy_6 = t_8 = new_ones_2 = t_9 = None
        view_31: "bf16[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(_scaled_mm_1, [64, 1, 64]);  _scaled_mm_1 = None
        view_32: "bf16[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.view.default(view_31, [64, 1, 1, 64]);  view_31 = None
        slice_13: "f32[64, 1][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(div_2, 0, 0, 9223372036854775807);  div_2 = None
        unsqueeze_13: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_13, 1);  slice_13 = None
        slice_14: "f32[64, 1, 1][1, 1, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_13, 2, 0, 9223372036854775807);  unsqueeze_13 = None
        unsqueeze_14: "f32[64, 1, 1, 1][1, 1, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_14, 3);  slice_14 = None
        mul_6: "f32[64, 1, 1, 64][64, 64, 64, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_32, unsqueeze_14);  view_32 = unsqueeze_14 = None
        view_33: "f32[64, 1, 64][64, 64, 1]cuda:0" = torch.ops.aten.view.default(mul_6, [64, 1, 64]);  mul_6 = None
        view_34: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_33, [64, 64]);  view_33 = None
        view_35: "f32[1, 64, 64][4096, 64, 1]cuda:0" = torch.ops.aten.view.default(view_34, [1, 64, 64]);  view_34 = None
        view_36: "f32[1, 64, 64, 1][4096, 64, 1, 1]cuda:0" = torch.ops.aten.view.default(view_35, [1, 64, 64, 1]);  view_35 = None
        slice_15: "f32[1, 64][1, 1]cuda:0" = torch.ops.aten.slice.Tensor(t_6, 0, 0, 9223372036854775807);  t_6 = None
        unsqueeze_15: "f32[1, 1, 64][1, 64, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_15, 1);  slice_15 = None
        slice_16: "f32[1, 1, 64][1, 64, 1]cuda:0" = torch.ops.aten.slice.Tensor(unsqueeze_15, 2, 0, 9223372036854775807);  unsqueeze_15 = None
        unsqueeze_16: "f32[1, 1, 64, 1][1, 64, 1, 1]cuda:0" = torch.ops.aten.unsqueeze.default(slice_16, 3);  slice_16 = None
        mul_7: "f32[1, 64, 64, 1][4096, 64, 1, 1]cuda:0" = torch.ops.aten.mul.Tensor(view_36, unsqueeze_16);  view_36 = unsqueeze_16 = None
        view_37: "f32[64, 64, 1][64, 1, 1]cuda:0" = torch.ops.aten.view.default(mul_7, [64, 64, 1]);  mul_7 = None
        view_38: "f32[64, 64][64, 1]cuda:0" = torch.ops.aten.view.default(view_37, [64, 64]);  view_37 = None
        _to_copy_9: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten._to_copy.default(view_38, dtype = torch.bfloat16);  view_38 = None
        t_10: "bf16[64, 64][1, 64]cuda:0" = torch.ops.aten.t.default(view_19)
        mm: "bf16[64, 64][64, 1]cuda:0" = torch.ops.aten.mm.default(t_10, primals_1);  t_10 = primals_1 = None
        sum_1: "bf16[64][1]cuda:0" = torch.ops.aten.sum.dim_IntList(view_19, [0]);  view_19 = None
        return (_to_copy_9, mm, sum_1)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148922
Approved by: https://github.com/zou3519
2025-03-18 20:08:11 +00:00
Aaron Gokaslan
a0ac63cbd9 [BE]: Apply ruff PERF403 to use dict comprehensions more often (#149257)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149257
Approved by: https://github.com/jansel
2025-03-18 00:46:07 +00:00
PyTorch MergeBot
24cfeec2c7 Revert "[BE]: Apply ruff PERF403 to use dict comprehensions more often (#149257)"
This reverts commit bfee141666.

Reverted https://github.com/pytorch/pytorch/pull/149257 on behalf of https://github.com/malfet due to Let's see if it helps restore compiler benchmark sanity, see 8bc7bd94a5/1 ([comment](https://github.com/pytorch/pytorch/pull/149257#issuecomment-2731133812))
2025-03-17 22:57:00 +00:00
Aaron Gokaslan
bfee141666 [BE]: Apply ruff PERF403 to use dict comprehensions more often (#149257)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149257
Approved by: https://github.com/jansel
2025-03-16 23:52:58 +00:00
Brian Hirsh
3646d4dbc8 [partitioner] always ban compiler-driven recompute of collectives by default (#147561)
This should fix the hang in https://fb.workplace.com/groups/1075192433118967/permalink/1603268720311333/

The argument here is that:

(1) in general, it is not safe for the partitioner to sometimes choose to recompute collectives in the backward. Why? If we are running a distributed job, where many ranks are compiling at the same time, we need every rank to make a consistent decision about which collectives are recomputed for backward. If we let each compiler instance make its own choice without any cross-rank communication, they can make different choices and cause NCCL hangs (see the link above)

(2) later on, we'll want an `spmd_mode` flag that causes the compiler to issue collectives and communicate info across ranks. Once we have such a config, then turning it on should make it safe for the partitioner to potentially choose to recompute collectives (and agree on the binary "recompute-or-save" choice across all ranks)

(3) even without an `spmd_mode`, users can override this choice by using `torch.utils.checkpoint()` in their user code. User checkpointing generally always overrides the partitioner, and this should be safe because we expect the user to apply checkpointing consistently across ranks

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147561
Approved by: https://github.com/zou3519
2025-03-13 03:36:13 +00:00
Brian Hirsh
621dadd4ca partitioner: when materializing unbacked tensor intermediates, apply hint to symbol, not expr (#144097)
Fixes https://github.com/pytorch/pytorch/issues/144095

open to suggestions: the `hint_int(..., fallback=...)` API feels like a bit of a footgun, because:

(1) we use the same guess for every unbacked symint (both symbols, and compound expressions)
(2) the user may have established some relationship between some unbacked symints that we are not taking into account.

I'm not sure how real of an issue (2) is - is it common to e.g. generate two unbacked symints, and then add a runtime assert that they are unequal?

Instead I did something simpler that's just enough to fix the linked issue: if we have a sympy expression containing an unbacked symbol (e.g. `u0 + 1`), then the partitioner will now fill in the symbol with our guess instead of the expression (plugging in `u0=4096` gets us 4097). This was important for an internal custom op, that had some logic like this:
```
def custom_op(x: [u0], y: [u0 + 1]):
    assert x.shape[0] = y.shape[0] - 1
    ...
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144097
Approved by: https://github.com/laithsakka
2025-03-11 02:11:57 +00:00
Luca Wehrstedt
f80aad62fa Improve Pareto frontier plot for AutoAC (#148678)
This was added in https://github.com/pytorch/pytorch/pull/126320. It's a very nice feature, which can be used to predict memory usage for different budget values.

However, it had some limitations, notably in terms of resolution (it only sampled 21 points across the whole range thus missed many threshold values) and in distributed settings.

Here I fix those by using recursive binary searches to identify all thresholds (up to a resolution of 1e-3, which can be made configurable) and output them in SVG (to be able to discern different points), plus I add the rank to the filename and store it in a user-define directory.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148678
Approved by: https://github.com/Chillee, https://github.com/fmassa
2025-03-07 13:22:29 +00:00
eellison
481a57bc37 Support torch.compile rng selective activation checkpointing with cudagraph (#146878)
TODO:
- [x]  Add handling for when forward is invoked multiple times without invoking backward, so that the fwd/backward states are out of sync
- [x] Update rng state initialization to take from correct device
- [x]  Tests
- [x] handling of retain_graph
- [x] respect fallback random

Fix for https://github.com/pytorch/pytorch/issues/130123.

Updates the aot_eager and cudagraph compilation of `run_and_save_rng_state` to use the new mechanism added by https://github.com/pytorch/pytorch/pull/114068 for CUDAGraph safe rng states.

We have a pair of rng states for the fwd and backward respectively. In both forward and backward the rng op will get run with `graphsafe_run_with_rng_state` which takes in RNG state and it hooks onto the current RNG generator before running the operator. The rng states for fwd/backward are initialized with the same value. We ensure that for any given run of the forward, the corresponding backward run will have the same rng states for the op as was observed in the forward.

```
 ===== Forward graph 1 =====
 /data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", fwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = fwd_rng_state_0);  fwd_rng_state_0 = None
        ...

 ===== Backward graph 1 =====
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", tangents_1: "f32[4, 4][4, 1]cuda:0", bwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = bwd_rng_state_0);  bwd_rng_state_0 = None
```

There is some extra complication when a user either calls backward with retain_graph, or calls the backward in a different order as they called the forward. If a user has state fwd_rng_state0, bwd_rng_state0 and calls:
- fwd0: fwd_rng_state0 -> fwd_rng_state1
- fwd1: fwd_rng_state1 -> fwd_rng_state2
- bwd1
- bwd0

Then naively, when bwd1 is invoked the bwd rng states would not be equal to the same states that were observed in fwd1. I added handling of this in the aot runtime wrappers to detect pending backward invocations, and the current position of the bwd rng states, and to update when necesssary.

Other notes:

Because nodes which appear later in the forward appear earlier in the backward, we need a separate rng state for each operator. If we reused the rng across ops, the forward and backward would be run with different rng states. I.e., not applied in the same order.

Questions for reviewers:

This does change numerics, bc the rng of the op is now taken from the input rng state instead of whatever the rng would be midway through running the graph. Technically, we only need this for cuda graph. But, I'd prefer to not have a rng divergence just for cudagraph. I am making it respect `fallback_random`.

Edit: decided to apply to non cudagraphs as well, so long as fallback_random is not set

I'm initializing the rng states by cloning the current state. If you had something like 5 different rands in the model with the same shape, theyd all get the same value. This doesn't seem great. I could use some other initialization scheme like taking seed from graph position, or etc etc. Not sure. Let me know thoughts.

Edit: updated to be taken from randint()

Update: initializing rng states from torch.randint..

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146878
Approved by: https://github.com/anijain2305, https://github.com/bdhirsh
2025-02-28 00:47:03 +00:00
PyTorch MergeBot
17358ce778 Revert "Support torch.compile rng selective activation checkpointing with cudagraph (#146878)"
This reverts commit ad0c879e22.

Reverted https://github.com/pytorch/pytorch/pull/146878 on behalf of https://github.com/wdvr due to lint failure ([comment](https://github.com/pytorch/pytorch/pull/146878#issuecomment-2686767956))
2025-02-27 03:36:16 +00:00
eellison
ad0c879e22 Support torch.compile rng selective activation checkpointing with cudagraph (#146878)
TODO:
- [x]  Add handling for when forward is invoked multiple times without invoking backward, so that the fwd/backward states are out of sync
- [x] Update rng state initialization to take from correct device
- [x]  Tests
- [x] handling of retain_graph
- [x] respect fallback random

Fix for https://github.com/pytorch/pytorch/issues/130123.

Updates the aot_eager and cudagraph compilation of `run_and_save_rng_state` to use the new mechanism added by https://github.com/pytorch/pytorch/pull/114068 for CUDAGraph safe rng states.

We have a pair of rng states for the fwd and backward respectively. In both forward and backward the rng op will get run with `graphsafe_run_with_rng_state` which takes in RNG state and it hooks onto the current RNG generator before running the operator. The rng states for fwd/backward are initialized with the same value. We ensure that for any given run of the forward, the corresponding backward run will have the same rng states for the op as was observed in the forward.

```
 ===== Forward graph 1 =====
 /data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", fwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = fwd_rng_state_0);  fwd_rng_state_0 = None
        ...

 ===== Backward graph 1 =====
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", tangents_1: "f32[4, 4][4, 1]cuda:0", bwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = bwd_rng_state_0);  bwd_rng_state_0 = None
```

There is some extra complication when a user either calls backward with retain_graph, or calls the backward in a different order as they called the forward. If a user has state fwd_rng_state0, bwd_rng_state0 and calls:
- fwd0: fwd_rng_state0 -> fwd_rng_state1
- fwd1: fwd_rng_state1 -> fwd_rng_state2
- bwd1
- bwd0

Then naively, when bwd1 is invoked the bwd rng states would not be equal to the same states that were observed in fwd1. I added handling of this in the aot runtime wrappers to detect pending backward invocations, and the current position of the bwd rng states, and to update when necesssary.

Other notes:

Because nodes which appear later in the forward appear earlier in the backward, we need a separate rng state for each operator. If we reused the rng across ops, the forward and backward would be run with different rng states. I.e., not applied in the same order.

Questions for reviewers:

This does change numerics, bc the rng of the op is now taken from the input rng state instead of whatever the rng would be midway through running the graph. Technically, we only need this for cuda graph. But, I'd prefer to not have a rng divergence just for cudagraph. I am making it respect `fallback_random`.

Edit: decided to apply to non cudagraphs as well, so long as fallback_random is not set

I'm initializing the rng states by cloning the current state. If you had something like 5 different rands in the model with the same shape, theyd all get the same value. This doesn't seem great. I could use some other initialization scheme like taking seed from graph position, or etc etc. Not sure. Let me know thoughts.

Edit: updated to be taken from randint()

Update: initializing rng states from torch.randint..

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146878
Approved by: https://github.com/anijain2305, https://github.com/bdhirsh
2025-02-27 02:08:29 +00:00
Brian Hirsh
89b9c12de8 remove prints from partitioner (#147749)
See c57894cd74..22d8f9a657 (r1968015955)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147749
Approved by: https://github.com/Skylion007, https://github.com/laithsakka
2025-02-24 21:03:45 +00:00
Laith Sakka
454fbd5bbe realize stride symbols in estimate_runtime (#146752)
Unfortuanlty could not create a local repo, or unit test.
fix https://github.com/pytorch/pytorch/issues/146686

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146752
Approved by: https://github.com/bobrenjc93, https://github.com/bdhirsh
2025-02-19 06:02:49 +00:00
Basil Wong
05001f0459 Add Structured Tracing for Traced Graph Edge Details for AC Debugging (#146634)
Summary:
Updating the structured trace infrastructure so that we are able to output to Zoomer and have an E2E solution.

Context Doc: https://docs.google.com/document/d/1T6omIBEWVhbOiwDLSLffgQwjxiT2rQv8QvvQwXkw4fY/edit?usp=sharing

Test Plan:
### Testing Structured Log + tlparse locally

Command:
```
TORCH_TRACE=/data/users/basilwong/fbsource/fbcode/log_torch_trace buck2 run mode/opt //aps_models/ads/icvr:icvr_launcher -- mode=local_fb_fm_v4 launcher.num_workers=2
```

Torch Trace Logs (local then sent to paste): P1686419449
```
cat log_torch_trace/dedicated_log_torch_trace_rank_0_2lg012xo.log | pastry
P1686419449
```

tlparse output: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpyiv5wj/rank_1/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=100

tlparse graph edge details output: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpyiv5wj/rank_1/9_0_0/joint_graph_information_397.txt?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=100

Differential Revision: D61557220

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146634
Approved by: https://github.com/jansel, https://github.com/Yuzhen11
2025-02-14 02:04:26 +00:00
Aaron Gokaslan
292af3cc89 [BE][Ez]: ISC001 Auto concatenate implicit one line strings (#146408)
Apply ruff rule about implicit string concatenation, this autofixes strings that are all the same type and on the same line. These lines are broken up likely as the result of autoformatters in the past. All fixes are automated using the autofixes in ISC001.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146408
Approved by: https://github.com/justinchuby, https://github.com/janeyx99
2025-02-04 19:07:04 +00:00
Sam Larsen
23fffb54d5 Use OrderedSet in _functorch/partitioners (#146102)
In an attempt to make partitioning more deterministic, change all sets in partitioners.py to OrderedSets. Note that this change does not fix the non-determinism we're seeing in the internal model. But let's at least eliminate this potential source of non-determinism before investigating any changes to the mincut approach?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146102
Approved by: https://github.com/oulgen
2025-02-04 17:43:07 +00:00
Brian Hirsh
ed141d7d1a dont assign a size to _assert_scalar in partitioner (#143877)
Fixes https://github.com/pytorch/pytorch/issues/143876

Open to other suggestions - we have an invariant that all nodes in our ATen graphs should have a `meta['val']` field, but I don't think this is actually true in all cases, so I just hardcoded the invariant to ignore `_assert_scalar()` (which is a "special" op used in dynamic shapes for runtime asserts, and doesn't have a meta['val'] field)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143877
Approved by: https://github.com/zou3519
2025-01-29 16:21:37 +00:00
Brian Hirsh
7ca156f0ee partitioner: avoid inserting duplicates into heap (#145082)
Fixes https://github.com/pytorch/pytorch/issues/145081

This looks like it was a source of quadratic compile times in the torchtitan CP graphs. There's some code in the partitioner that iteratively adds users of a node to a heap, and pops the earliest user. If you have long parallel chains of fusible ops that all eventually feed into some shared ops, then this can result in:
(1) a node getting added to the heap many times
(2) each time we pop that node, we add (duplicates of) each of that node users to the heap
(3) repeat with each user

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145082
Approved by: https://github.com/xmfan
2025-01-28 23:44:45 +00:00
Aaron Gokaslan
5ebca3015d [BE]: Simplify set add with set update (#145152)
Simplifies the set update slightly to be more readable and efficient.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145152
Approved by: https://github.com/XuehaiPan, https://github.com/albanD

Co-authored-by: Xuehai Pan <XuehaiPan@outlook.com>
2025-01-23 20:18:13 +00:00
Li Yu (ads)
e6a84be3d3 [PyTorch] Add backend aot_eager_decomp_partition_with_mode (#143250)
Summary:
## Why
To make it possible to run torch dispatch mode inside compiled modules. This is to enable running MemoryTrackerMode (in next diff) to collect memory usage of compiled modules.

## What
Add a backend aot_eager_decomp_partition_with_mode.
Add an enable_log to the backend to control the compilation logging (which can be very verbose and slow the run of mode)

Test Plan:
unittest

E2e tested in the next diff which shows the memory read from the mode passed to this backend is very close to the actual job's memory snapshot.

Differential Revision: D67227144

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143250
Approved by: https://github.com/bdhirsh
2025-01-22 23:20:59 +00:00
PyTorch MergeBot
6e53588789 Revert "[BE]: Simplify set add with set update (#145152)"
This reverts commit 0cb9b2284a.

Reverted https://github.com/pytorch/pytorch/pull/145152 on behalf of https://github.com/davidberard98 due to land race with https://github.com/pytorch/pytorch/pull/145165 broke lint ([comment](https://github.com/pytorch/pytorch/pull/145152#issuecomment-2608378172))
2025-01-22 22:14:26 +00:00
Aaron Gokaslan
0cb9b2284a [BE]: Simplify set add with set update (#145152)
Simplifies the set update slightly to be more readable and efficient.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145152
Approved by: https://github.com/XuehaiPan, https://github.com/albanD

Co-authored-by: Xuehai Pan <XuehaiPan@outlook.com>
2025-01-22 21:31:13 +00:00
Aaron Orenstein
78bff1e8c1 PEP585 update - torch/_functorch (#145139)
See #145101 for details.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145139
Approved by: https://github.com/bobrenjc93
2025-01-19 07:06:10 +00:00
Basil Wong
7b2af25f80 [1/n] Support Dynamic Memory Budget in Auto AC (#143539)
# Summary:
Full Context: https://docs.google.com/document/d/1-j5KSbfGFJQcH4sYh7BIeJXso3zYzl5G5yFQqXdKx_o/edit?usp=sharing

tl;dr

This change introduces classes which help determine a dynamic memory budget. This will mostly be helpful for models with many implicit graph breaks.

---

New Classes:

*GraphInfoProvider*
* Takes the joint_graph as well as the input memories and runtimes and parses the graph + values into usable forms for the SolverEvaluator.

*KnapsackEvaluator*
* Provides a function: Given all of the four inputs (solver function as a callable, max_dynamic_memory_budget, min_dynamic_memory_budget, dynamic_memory_budget_pareto_granularity) it returns an approximation of the knee point of the pareto distribution.

# Test Plan:

### LintRunner

LintRunner Output: P1700445547

### Unit Tests

```
$ buck test @mode/opt //caffe2/test/functorch:test_ac_knapsack
`@mode/opt` was specified, but not found. Using file at `//mode/opt`.
This behavior is being deprecated. Please use `"@//mode/opt"` instead
File changed: fbcode//caffe2/.ruff_cache/0.7.4/.tmpB6PmDS
File changed: fbsource//xplat/caffe2/test/functorch/test_ac_knapsack.py
File changed: fbcode//caffe2/.ruff_cache/0.7.4/.tmpyjCiPn
20 additional file change events
Buck UI: https://www.internalfb.com/buck2/414ead46-9ede-4192-8e1a-5d3c52bdb9cc
Test UI: https://www.internalfb.com/intern/testinfra/testrun/6473924710342830
Network: Up: 0B  Down: 0B  (reSessionID-159794b9-9d61-477e-8e63-9bdeaa537dca)
Analyzing targets. Remaining     0/214
Executing actions. Remaining     0/6933                                                                                                                                                                                  0.1s exec time total
Command: test.     Finished 1 local
Time elapsed: 18.5s
Tests finished: Pass 15. Fail 0. Fatal 0. Skip 0. Build failure 0
```

### Test Run

Updated the config:

```
      activation_memory_budget_solver: DYNAMIC_MEMORY_BUDGET_DP
```

Confirming proper execution via: [aps-fb_fm_v4_768_01_dynamic-2a792ba8af](https://www.internalfb.com/mlhub/pipelines/runs/mast/aps-fb_fm_v4_768_01_dynamic-2a792ba8af?job_attempt=0&version=0&env=PRODUCTION)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143539
Approved by: https://github.com/jansel
2024-12-21 07:38:52 +00:00
Tom Ritchford
dc23f1944a Remove unused Python variables in torch/[_-a]* (#133492)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133492
Approved by: https://github.com/albanD
2024-12-12 17:39:14 +00:00
PyTorch MergeBot
5c97ac9721 Revert "Remove unused Python variables in torch/[_-a]* (#133492)"
This reverts commit fda975a7b3.

Reverted https://github.com/pytorch/pytorch/pull/133492 on behalf of https://github.com/clee2000 due to Sorry, I need to revert this in order to revert something else.  The only thing you need to do is rebase and remerge ([comment](https://github.com/pytorch/pytorch/pull/133492#issuecomment-2536635516))
2024-12-11 17:29:12 +00:00
Tom Ritchford
fda975a7b3 Remove unused Python variables in torch/[_-a]* (#133492)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133492
Approved by: https://github.com/albanD
2024-12-10 21:48:44 +00:00
Basil Wong
2d9b081012 Move knapsack algorithms into separate filing (#141451) (#141614)
Summary:

Original Context Doc: https://docs.google.com/document/d/1Gv5ZqN3UY7kTuCUd0JLU3L_onwVIr2vd3AiZyim1Muc/edit?disco=AAABZX_tqdk

### Changes

This diff restructures the Partitioners.py file in the _functorch package.

* Moves the three knapsack problem algorithems (greedy, ilp, dp) into a separate file

Test Plan:
### Unit Testing

```
$ buck test mode/opt //caffe2/test/functorch:test_ac
File changed: fbsource//xplat/caffe2/test/functorch/TARGETS
File changed: fbsource//xplat/caffe2/test/functorch
File changed: fbsource//xplat/caffe2/test
7 additional file change events
Soft Error: source_directory_includes_subpackage: Directory `v2.17.1-1` of package `fbsource//third-party/nccl` may not cover any subpackages, but includes subpackage `v2.17.1-1/src/tests`.
Soft Error: source_directory_includes_subpackage: Directory `v2.18.3-1` of package `fbsource//third-party/nccl` may not cover any subpackages, but includes subpackage `v2.18.3-1/src/tests`.
Soft Error: source_directory_includes_subpackage: Directory `v2.19.3-1` of package `fbsource//third-party/nccl` may not cover any subpackages, but includes subpackage `v2.19.3-1/src/tests`.
Buck UI: https://www.internalfb.com/buck2/a2f91f8a-5326-435e-9075-5af0de930b8b
Test UI: https://www.internalfb.com/intern/testinfra/testrun/7036874660108924
Network: Up: 28MiB  Down: 3.5GiB  (reSessionID-19af3d71-7528-448c-9126-5615d27b3bd7)
Jobs completed: 423656. Time elapsed: 3:57.2s.
Cache hits: 99%. Commands: 146147 (cached: 145758, remote: 317, local: 72)
Tests finished: Pass 8. Fail 0. Fatal 0. Skip 0. Build failure 0

```

### Tested on Local Training Run

```
CUDA_VISIBLE_DEVICES=5,6 AOT_PARTITIONER_DEBUG=1 PARTITIONER_MEMORY_BUDGET_PARETO=0 buck2 run mode/opt //aps_models/ads/icvr:icvr_launcher -- mode=local_fb_fm_v4 launcher.num_workers=2 2>&1 | tee log_2024-11-2320:41:50.757256__bento_trigger.txt
```

Output Summary Paste: P1685697066

Differential Revision: D65800097

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141614
Approved by: https://github.com/jansel, https://github.com/Chillee
2024-12-07 03:21:52 +00:00
Francisco Massa
e5d02e0cfb Fix non-determinism in the partitioner (#141682)
When multiple nodes have similar sizes and are part of the `banned_nodes` (which is a `set` and not a `list`), there is non-determinism present in the partitioner due to sorting only by node-size.

This PR fixes this by also sorting by node name.

It would be good to add some tests, but I'm not sure about the best way to do it here.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141682
Approved by: https://github.com/Chillee, https://github.com/yf225
2024-11-27 19:33:15 +00:00
Basil Wong
00c829876c Log Full Knapsack Problem Information (#140757)
Summary: When AOT_PARTITIONER_DEBUG is set to 1 and debug logging is turned on we can now log the full input and output for each knapsack problem.

Differential Revision: D65633086

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140757
Approved by: https://github.com/jansel
2024-11-18 20:36:32 +00:00
Basil Wong
2f1dbfea02 Logging Refactor - Remove Print Statements (#139782)
Summary:
Removes print statements and implements logging via the logging library.

Hopefully this will allow more control on the level of logging when running models.

Test Plan:
```
AOT_PARTITIONER_DEBUG=1 buck2 run @mode/opt //aps_models/ads/icvr:icvr_launcher -- mode=local_fb_fm_v4 launcher.num_workers=2
```

Resulting output paste: P1674535630
* Full logs paste: P1674535621

```
pastry P1674535621 | grep "functorch/partitioners.py" | pastry
```

Logging results: P1674549514

Differential Revision: D61678215

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139782
Approved by: https://github.com/paryxyt, https://github.com/jansel
2024-11-13 23:09:18 +00:00
PyTorch MergeBot
6dada2136a Revert "Refactor FxGraphDrawer to use HTML-like labels (#137726)"
This reverts commit 1e73842029.

Reverted https://github.com/pytorch/pytorch/pull/137726 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but it looks like some internal components are failing after this change and need to be updated ([comment](https://github.com/pytorch/pytorch/pull/137726#issuecomment-2455332612))
2024-11-04 17:44:44 +00:00
Gabriel Ferns
1e73842029 Refactor FxGraphDrawer to use HTML-like labels (#137726)
Fixes https://github.com/pytorch/pytorch/issues/137499
Testing: Added a new unit test to make sure that the regression case succeeds.
I'm debating about whether to make the borders visible. I'm partial to no borders, but it might make it harder for some people to read?
![68a2b0e3-orig_fx_graph_diagram](https://github.com/user-attachments/assets/fbc2fd98-9e76-488e-8ebe-c64fbf206932)
Vs.
![2bfe1c4f-orig_fx_graph_diagram](https://github.com/user-attachments/assets/b6bc88ba-dda2-4cf7-84ac-a615e1e03a74)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137726
Approved by: https://github.com/eellison, https://github.com/malfet
2024-11-01 23:19:50 +00:00
Parikshit Shah
853da168fc [AC] Backward Pass Aware AC - adding hooks to partitioner to pass callable (#137785)
Summary: same as title. Plan is to pass a callable to the partitioner to perform custom autoAC via an ILP. This is the same as a previous diff D63714905 which was landed and then subsequently reverted by PyTorch Release Engineering because of a failing unit test (f7b8d36c28). We think the unit test is buggy, and we also fix the same.

Test Plan: tbd

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137785
Approved by: https://github.com/basilwong

Co-authored-by: Huy Do <huydhn@gmail.com>
2024-10-21 21:45:13 +00:00
PyTorch MergeBot
9bb327bfc6 Revert "[AC] Backward Pass Aware AC - adding hooks to partitioner to pass callable (#137785)"
This reverts commit a8b912f39d.

Reverted https://github.com/pytorch/pytorch/pull/137785 on behalf of https://github.com/ezyang due to breaks lint ([comment](https://github.com/pytorch/pytorch/pull/137785#issuecomment-2427295668))
2024-10-21 17:18:56 +00:00
Parikshit Shah
a8b912f39d [AC] Backward Pass Aware AC - adding hooks to partitioner to pass callable (#137785)
Summary: same as title. Plan is to pass a callable to the partitioner to perform custom autoAC via an ILP. This is the same as a previous diff D63714905 which was landed and then subsequently reverted by PyTorch Release Engineering because of a failing unit test (f7b8d36c28). We think the unit test is buggy, and we also fix the same.

Test Plan: tbd

Differential Revision: D64246495

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137785
Approved by: https://github.com/basilwong
2024-10-21 15:30:07 +00:00
Tom Ritchford
15722debfb Remove two unused variables in _functorch/partitioners.py (#137998)
* Extracted from https://github.com/pytorch/pytorch/pull/133492

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137998
Approved by: https://github.com/Skylion007
2024-10-16 10:58:31 +00:00
PyTorch MergeBot
2fff990c16 Revert "[AutoAC] Backward Pass Aware AC - changes to partitioner to acommodate SOLVER as a callable (#137314)"
This reverts commit 932b9945c0.

Reverted https://github.com/pytorch/pytorch/pull/137314 on behalf of https://github.com/huydhn due to The failure shows up in trunk ([comment](https://github.com/pytorch/pytorch/pull/137314#issuecomment-2401311719))
2024-10-09 04:53:30 +00:00
Parikshit Shah
932b9945c0 [AutoAC] Backward Pass Aware AC - changes to partitioner to acommodate SOLVER as a callable (#137314)
Summary: making it so that the config can pass `config.activation_memory_budget_solver` as a callable method and then that callable is invoked to determine the set of saved/recomputed nodes.

Test Plan: tbd

Reviewed By: Chillee, basilwong

Differential Revision: D63714905

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137314
Approved by: https://github.com/eellison, https://github.com/basilwong

Co-authored-by: Parikshit Shah <parikshit@meta.com>
2024-10-09 00:39:29 +00:00
Brian Hirsh
bf73af4b4e dont let partitioner think it can fuse pointwise ops into user triton kernels (#136878)
Previously if we had a graph like:
```
        triton_kernel_wrapper_functional_proxy = triton_kernel_wrapper_functional(...)
        getitem: "f32[3][1]cuda:0" = triton_kernel_wrapper_functional_proxy['out_ptr']
        getitem_1: "f32[3][1]cuda:0" = triton_kernel_wrapper_functional_proxy['out2_ptr']
        sigmoid: "f32[3][1]cuda:0" = torch.ops.aten.sigmoid.default(getitem_1)
        mul: "f32[3][1]cuda:0" = torch.ops.aten.mul.Tensor(tangents_1, sigmoid)
```

The partitioner would assume that the `sigmoid()` could be fused into either its user (the pointwise mul), or its producer (the user triton kernel). This could lead to a bad partitioning:

(1) If the partitioner thinks we can fuse the sigmoid with its producer triton kernel, we would keep the sigmoid compute in the forward, and have to generate two separate kernels in the forward (user triton kernel, dedicated sigmoid kernel)

(2) if the partitioner puts the sigmoid in the backward instead, we could fuse it with an existing backward kernel (the mul with a tangent)

Reviewed By: embg

Differential Revision: D63551393

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136878
Approved by: https://github.com/zou3519
2024-10-02 13:52:44 +00:00
Will Feng
a815611db9 [Traceable FSDP2][Partitioner] Must save AC output if output has a backward hook (#135727)
If node is AC region output and has a backward hook on it, we intentionally choose to save it.
This is to work around circular dependencies in Traceable FSDP2+AC.
Example:
```
out = fully_shard(utils.checkpoint(module))(x)
norm_out = layer_norm(out)
```
and there is a circular dependency:
1. In backward, grad_input of layer_norm aka. `out_grad` is actually dependent on `out`.
2. `out` depends on `out`'s backward hook created by FSDP2 (which does all-gather for `module` weights) in order to be recomputed.
3. `out`'s FSDP2 backward hook, as is the case for all eager backward hooks, depends on `out_grad`  -> circular dependency with (1)!

Solution: check whether `out` has a backward hook, and if so, intentionally save `out` in forward graph outputs. With this, we can break the above circular dependency.

----

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135727
Approved by: https://github.com/Chillee
2024-09-14 08:45:58 +00:00
Will Feng
94d2471d1f [Traceable FSDP2] Use .copy_ instead of .set_ for unsharded_param inplace update; Replace unsharded_param graph input usage with graph intermediate; Support FSDP2+LoRA (#133730)
Using `fsdp.set_` for unsharded_param inplace update causes difficult-to-debug errors when enabling Traceable FSDP2 on TorchTune models. In this PR, we change it to use `fsdp.copy_` which fixes the error and also strictly follows eager semantics (i.e. if user explictly stores an alias of the unsharded_param during execution of the user's module code, that alias will get updated correctly when the unsharded_param is copy_ into; whereas if we just swap out unsharded_param storage via set_, that user-saved alias will not get updated, which is not good).

This PR also implements the graph pass to remove the resizes and copy if there is a resize_(full) -> copy_ -> resize_(0) pattern.

------

Test commands:
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_nested_fully_shard_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_trace_fsdp_copy_`
- `pytest -rA test/dynamo/test_repros.py::ReproTests::test_partitioner_cse_respects_mutation_boundaries`
- `pytest -rA test/dynamo/test_repros.py::ReproTests::test_fsdp_set_input_mutation_applied_when_input_gets_no_gradients`
- `pytest -rA test/inductor/test_pattern_matcher.py::TestPatternMatcher::test_mutation_op_matching`
- `python test/inductor/test_distributed_patterns.py DistributedPatternTests.test_fake_distributed_aot_eager`
- `PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=1 PYTORCH_TEST_WITH_CROSSREF=1 python test/functorch/test_aotdispatch.py TestEagerFusionOpInfoCPU.test_aot_autograd_exhaustive_norm_cpu_float32`
- `python test/distributed/test_inductor_collectives.py TestCollectivesInductor.test_backwards`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133730
Approved by: https://github.com/bdhirsh
2024-09-11 23:01:05 +00:00
Laith Sakka
c8ab9b06a2 Redesign custom op functionlaization for better re-inplace (#134409)
- The new implementation (auto_functionalized_v2) is enabled by default but can be disable
 using an inductor flag.
- In export mode the old implementation is used.

**Motiviation**
Previous functionalization fails to re-inplace arguments when they are view over other tensors.
see issue https://github.com/pytorch/pytorch/issues/131192
The new functionalization is easier to re-inplace for views.

**A) Functionalizations pass**
consider a program:

```

func(t)
    x = t[0]
    y = t[1]
    foo(x, y) # custom operator with x, y mutable
    return (x, y, t)
```

- To functionalize `foo` we generate a function that operates on the base tensors of the inputs;  (x.base() and y.base())
and record how to regenerates the views out of the base for argument x by recording ```ViewInfo=(x.base(), x.size(), x.stride, x,storage_offset())```

- Due to some limitations on the torch.export arguments format, we have to generate alot of arguments, but this is something we can simplify in the future, for the example above we get the following function.

   ```
   auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.mylib.foo.default,
     _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0 ,
     _y_base_index = 0,_y_size = (), _y_stride = (), _y_storage_offset = 1   ,
     _all_bases = [arg0_1])
   ```
 -  In the code above:
        - _all_bases[t]: refers to a unique set of bases for all foo arguments.
        - for each argument x we have _x_base_index, _x_size, _x_stride, _x_storage_offset that can be used to (1)  regenerate x from _all_bases[_x_base_index] or a copy of a the base.

-  the output of auto_functionalized is foo output , followed by x tensors one for each base in  _all_bases, that is a copy of the base tensor after observing the mutations of the all the arguments that are views of that base.

-  for each use of a base in _all_bases or a view of it , that are after the call to foo, replace it with a view of the new output

 for the function above after functionalization we get :
 ```
    def forward(self, arg0_1: "f32[2][1]cpu"):
        auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_base_index = 0, _y_size = (), _y_stride = (), _y_storage_offset = 1, _all_bases = [arg0_1])
        getitem_1: "f32[2][1]cpu" = auto_functionalized[1];  auto_functionalized = None
        copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1);  arg0_1 = copy_ = None

        # No stacktrace found for following nodes
        select_2: "f32[][]cpu" = torch.ops.aten.select.int(getitem_1, 0, 0)
        select_3: "f32[][]cpu" = torch.ops.aten.select.int(getitem_1, 0, 1);  getitem_1 = None
        return (select_2, select_3)
```

**B) Semantics of  auto_functionalize**
The new semantics of auto_functionalize is as the following:
1. For each base in all_bases, copy the base and create all_bases copies. (if a base is inplaced we do not need to copy it)
2. For each arg, regenerate the arg from the copy of its base using the view information above.
3. return the original foo output followed by the new bases.

**C) Re-inplace pass**
since auto_functionalize not copy the bases, what we actually inplace is the bases.
 (run just like before but on the beses instead of args).

1. For each base b in _all_bases check if there is any use of base (or its aliases/views) after auto_functionalize (before its overwritten with a copy) if there is not any, then inplace it (avoid copying it in step 1 above).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134409
Approved by: https://github.com/zou3519
2024-09-04 17:08:58 +00:00
rzou
c582602245 Update partitioner's is_fusible heuristic to respect triton kernels (#134491)
mutated arguments to triton kernels are fusible into the triton kernel.

Test Plan:
- new test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134491
Approved by: https://github.com/Chillee
ghstack dependencies: #134364, #134466, #134490
2024-08-27 15:57:32 +00:00
rzou
c7cbcdad76 Update partitioner's is_fusible heuristic to respect auto_functionalized (#134490)
We say Node a is fusible into node b if node b is an auto_functionalized
node that may reinplace node a later on.

This PR also changes aten.empty to be recomputable w.r.t the Partitioner
(it is, like aten.zeros, cheap to recompute and fusible into other ops).

Fixes https://github.com/pytorch/pytorch/issues/134468

Test Plan:
- new test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134490
Approved by: https://github.com/Chillee
ghstack dependencies: #134364, #134466
2024-08-27 13:05:01 +00:00
IvanKobzarev
8ae4f82243 [aotd] Support HOP effects in backward (#132638)
Support of effectful operations in backward:

1/ AOTD collects metadata from forward fn only, so we can have usage of effectful ops in backward, that were not used in forward => Allowing tokens discovery during joint function .

FunctionalTensorMode holds _tokens, in Joint function after tracing forward we memoize _tokens as `_tokens_forward_output`.

2/ Tokens are added as primals inputs (forward) in EffectTokensWrapper.
Tokens that will be used in backward are in partitioner saved values. We do not have control on which positions they are saved in forward outputs.

2/ If new tokens discovered in backward after tracing joint_fn, the result graph will be manually added in the end of primals.
_aot_autograd/utils.py

3/ All effectful ops during backward are marked with 'must_be_in_backward' partitioner_tag, to prevent partiitoner to place them in forward.

For that functional_tensor_mode got new optional state `self._effects_partitioner_tag` for effectful ops, to set after tracing forward.

There are additional changes in partitioner to improve functionality of 'must_be_in_backward'

4/ Unlift tokens now should run for both forward and backward.
- As saved for backward tokens are placed on non static places - we identify input and output tokens to erase, by input and output of `with_effects` operation
- In forward we can have input tokens, discovered in backward, that are not used in with_effects ops in forward, but saved for backward. We identify them by position in forward inputs.

5/ Adding aot debug logging for graphs before unlifting and before adding additional primal for backward tokens.

Tests:
```
python test/higher_order_ops/test_with_effects.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132638
Approved by: https://github.com/bdhirsh
2024-08-23 15:30:58 +00:00
Aaron Orenstein
d95aedf5fd [BE] typing for decorators - fx/_compatibility (part 1) (#134202)
Part of #134054.

This corresponds to the pytorch mypy changes from D61493706. Updating takes so
long and touches so many files that it's impossible to land as a whole without conflicting with some other intermediate change.
So landing these 'type: ignore' for pytorch in advance of them actually being needed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134202
Approved by: https://github.com/Skylion007
2024-08-22 17:07:33 +00:00
IvanKobzarev
57625bacea [partitioner] Fix must_be_in_backward corner cases (#134002)
Preparation PR for https://github.com/pytorch/pytorch/pull/132638

"must_be_in_backward" fails the partitioner, if partitioner picks this node as saved_values.

The fix is to prevent partitioner to pick those nodes during nodes classification.

It's hard to make a test without making effectful ops in backward "must_be_in_backward", which will be testing this ( https://github.com/pytorch/pytorch/pull/132638 )

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134002
Approved by: https://github.com/bdhirsh
ghstack dependencies: #134003
2024-08-21 15:58:49 +00:00
Nicolas Macchioni
5cb05a82b4 [BC breaking] move benchmarking + prefer inductor path (#132827)
move benchmarking out of `torch._inductor.runtime.runtime_utils` and into `torch._inductor.runtime.benchmarking`, and prefer this path over directly accessing Triton's benchmarking

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132827
Approved by: https://github.com/eellison
2024-08-08 00:47:45 +00:00
Shuqi Yang
a74e5abda4 Fix issues in activation_memory_budget for float8 (#132687)
Summary:
When using activation_memory_budget for float8 training, two issues were noticed:

- When `aggressive_options` (https://fburl.com/code/m1yoskxw) is called , all fp8 gemms (the scaled_mm op) are saved for recomputation.
- After adding "scaled_mm" in the `compute_intensive_ops`, we got the next error from `estimate_runtime`: `mat2 must be col_major` from `meta_scaled_mm`.
To fix it, modified `materialize_arg` to also include the stride of the original tensor.

Test Plan: Run float8 training with `activation_memory_budget`.

Differential Revision: D60777297

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132687
Approved by: https://github.com/Chillee
2024-08-05 23:01:35 +00:00