Commit Graph

168 Commits

Author SHA1 Message Date
Yuanyuan Chen
fc8ac1216c [4/N] Remove unused loop variables in tests (#166690)
This PR removes unused loop variables in tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166690
Approved by: https://github.com/justinchuby, https://github.com/mlazos
2025-10-31 10:20:48 +00:00
Boyuan Feng
bebabd7fce [Graph Partition] move custom rules to inductor config (#166458)
This PR adds `custom_should_partition_ops: list[str]` to specify the names of custom ops at which graph partition happens. It works with the cache since it is a `list[str]` in the config file. Op names should be in the format "mylib::baz".
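
For illustration, a minimal usage sketch, assuming the option is exposed as `torch._inductor.config.custom_should_partition_ops` as described above:

```python
import torch

# Hedged sketch: ask inductor to partition the graph around the listed custom
# ops; the attribute path is an assumption based on the PR description.
torch._inductor.config.custom_should_partition_ops = ["mylib::baz"]
```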

Close: #165341

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166458
Approved by: https://github.com/ProExpertProg, https://github.com/eellison, https://github.com/zou3519
2025-10-29 21:43:58 +00:00
Boyuan Feng
1891239a1d [Graph Partition] fix graph partition input signature for fallback kernels (#165815)
The scheduler relies on `node.last_usage` to free buffers. `last_usage` may contain a buffer that is allocated in a previous graph partition AND not directly accessed in the current graph partition.

## Example
```python
import torch

def f(x):
    y = x + 1
    z = torch.ops.aten.view.dtype(y, torch.float8_e4m3fn)
    z_cpu = z.cpu()
    u_cuda = z_cpu.cuda()
    return u_cuda
```

In the generated code, we have
```
def partition_0(args):
    ...
    # Topologically Sorted Source Nodes: [y, z], Original ATen: [aten.add, aten.view]
    buf1 = torch.ops.aten.view.dtype(buf0, torch.float8_e4m3fn) # < ------ buf1 is a view of buf0
    buf2 = buf1 # <------- buf2 is buf1
    assert_size_stride(buf2, (8, ), (1, ), 'torch.ops.aten.view.dtype')
    assert_alignment(buf2, 16, 'torch.ops.aten.view.dtype')
    return (buf2, )

def call(self, args):
    ...
    (buf2,) = self.partitions[0](partition0_args)
    ...
    buf3.copy_(buf2, False)
    del buf0
    del buf1
    del buf2  # <---- `del buf2` leads to `del buf0`. BUT `buf0` is not returned from partition_0.
    ...
```

Note: view is treated as a fallback kernel due to its special dtype.
de09bab4b6/torch/_inductor/lowering.py (L841-L843)

## Fix

This PR fixes the issue by also returning these buffers from the partition so they can be freed later.
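
For illustration, a hedged sketch of what the fixed partition output could look like (illustrative, not the actual generated code):

```
def partition_0(args):
    ...
    buf1 = torch.ops.aten.view.dtype(buf0, torch.float8_e4m3fn)
    buf2 = buf1
    # buf0 is now also returned, so the caller's `del buf2` / `del buf0` can free it safely
    return (buf2, buf0, )
```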

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165815
Approved by: https://github.com/eellison
2025-10-20 22:23:29 +00:00
Yuanyuan Chen
e925dfcc6b Enable all SIM rules except disabled ones (#164645)
`SIM` rules are useful for simplifying boolean expressions and enhancing code readability.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164645
Approved by: https://github.com/ezyang, https://github.com/mlazos
2025-10-17 07:27:11 +00:00
Boyuan Feng
f071f17911 [Graph Partition] fix partition x memory plan issue (#165514)
Before this PR, `test_graph_partition_with_memory_plan_reuse` would error when graph partition was enabled ([P1992728479](https://www.internalfb.com/phabricator/paste/view/P1992728479)):

```
def partition_0(args):
    ...
    del buf0
    return (buf3, buf4, buf5, buf2, primals_4, )

...

  File "/tmp/torchinductor_boyuan/ww/cwwc7ukfqscg2vy6ankby2fizdb377tvgyx3fwdgddrxe3g47jg6.py", line 132, in partition_0
    return (buf3, buf4, buf5, buf2, primals_4, )
                              ^^^^
NameError: name 'buf2' is not defined. Did you mean: 'buf0'?
```

When not using graph partition, it would work and give the following code ([P1992997521](https://www.internalfb.com/phabricator/paste/view/P1992997521)):

```
def call(self, args):
    ...
    buf2 = buf0; del buf0  # reuse
    ...
```

Note that the issue is that `buf0` is not reused for `buf2` when graph partition is enabled.

Why? Because codegen runs `run_wrapper_ir_passes` and `memory_plan_reuse`, which pop trailing `MemoryPlanningLine`s unless they are in the graph output, as determined by `V.graph.get_output_names()`. However, for graph partition, we should check the outputs of the current partition instead of those of the pre-partition graph.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165514
Approved by: https://github.com/ProExpertProg, https://github.com/eellison
2025-10-15 21:52:16 +00:00
Jeff Daily
e05c9c0c84 [ROCm][CI] cudagraph trees ut fixes (#163592)
Fixes #162125.
Fixes #160719.
Fixes #157901.
Fixes #157871.
Fixes #157761.
Fixes #157723.
Fixes #157643.
Fixes #157616.
Fixes #157556.
Fixes #157533.
Fixes #157449.
Fixes #157428.
Fixes #157413.
Fixes #157367.
Fixes #157350.
Fixes #157339.
Fixes #157312.
Fixes #157280.
Fixes #157258.
Fixes #157173.
Fixes #157143.
Fixes #157112.
Fixes #157086.
Fixes #157058.
Fixes #157035.
Fixes #156984.
Fixes #156957.
Fixes #156954.
Fixes #156922.
Fixes #156886.
Fixes #156838.
Fixes #156808.
Fixes #156801.
Fixes #156778.
Fixes #156755.
Fixes #156735.
Fixes #156693.
Fixes #152561.
Fixes #130749.
Fixes #100074.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163592
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-09-23 14:45:00 +00:00
rzou
ee7bdd8f2f [graph partition] Add way to register custom rule (#163310)
This PR adds an experimental way to register a custom rule for whether
inductor should partition the graph around an operator.

Test Plan:
- new test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163310
Approved by: https://github.com/ProExpertProg, https://github.com/BoyuanFeng, https://github.com/eellison
ghstack dependencies: #162117, #162307, #162651
2025-09-19 23:28:03 +00:00
Boyuan Feng
4967ad8baa [Graph Partition] improve custom op output alias (#163227)
For a custom op with multiple outputs, we will see the following generated code:
```
buf1 = op1(arg0)
buf3 = buf1[0]
buf4 = buf1[1]
del buf1 # <--- if buf1 is not accessed in the future
```

If `buf1` is not accessed in the future, it's good to deallocate it early rather than delaying the `del` until `buf3` and `buf4` are also no longer used. Note that `buf3` and `buf4` hold references to the data, so `del buf1` does not prevent their usage.

However, when there are mutating args, we don't see `del buf1` immediately.

```python
import torch

@torch.library.custom_op(
    "mylib::op1",
    mutates_args=["x"],
    schema="(Tensor(a!)?  x) -> (Tensor, Tensor)",
    device_types="cuda",
)
def op1(x) -> tuple[torch.Tensor, torch.Tensor]:
    x = x + 1
    return (x + 1, x + 2)
```

<img width="661" height="821" alt="image" src="https://github.com/user-attachments/assets/3d1d1f5a-9749-4652-bb02-da593c78702d" />

Why? Because `buf3` is a MultiOutput with `buf1` as input and believes `buf1` (an output of FallbackKernel op1) has inputs that alias its output.
72fedf0575/torch/_inductor/ir.py (L7976-L7982)

According to `[NOTE: FallbackKernel supported operators]`, for a mutating op that is auto-functionalizable, buf1's outputs should NOT alias any of the inputs. This PR improves `get_inputs_that_alias_output` of FallbackKernel accordingly.

Use case: [moe custom op in vllm](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fused_moe/layer.py#L2057-L2064)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163227
Approved by: https://github.com/zou3519
2025-09-19 17:01:36 +00:00
Boyuan Feng
333e546c02 [CUDAGraph][UX] warn many times for rerecording from dynamic shapes (#162696)
Excessive re-recording of CUDAGraphs leads to bad performance. We previously warned only once when this happened.

However, the limit (=50) is too high and users may observe bad performance before ever seeing the warning message. Even worse, users may not see the warning message at all when there are many other logs. @anijain2305 reported that he never saw this warning message when using the transformers library, but he DOES observe a slowdown due to cudagraph re-recording and needs to turn off cudagraph.

#162663 attempts to hard-error when re-recording too many times due to dynamic shapes, but it is a bc-breaking change; the hf-t5-generate model in torchbench failed due to 256 re-recordings.

This PR a) reduces the limit to a smaller value (=8); and b) makes the warning more frequent, i.e., warns once for every distinct shape once the limit is reached.
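
For illustration, a minimal sketch of the warning policy described above (not the actual inductor implementation; names and the limit constant are illustrative):

```python
import warnings

RERECORD_LIMIT = 8  # the smaller limit introduced by this PR

seen_shapes: set[tuple[int, ...]] = set()
warned_shapes: set[tuple[int, ...]] = set()

def on_cudagraph_rerecord(shape: tuple[int, ...]) -> None:
    # Stay quiet until the re-record limit is reached, then warn once per
    # distinct shape so the message is hard to miss without being unbounded spam.
    seen_shapes.add(shape)
    if len(seen_shapes) >= RERECORD_LIMIT and shape not in warned_shapes:
        warned_shapes.add(shape)
        warnings.warn(
            f"CUDAGraph re-recorded {len(seen_shapes)} times due to dynamic shapes; "
            f"latest shape: {shape}"
        )
```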

Fixes #162299

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162696
Approved by: https://github.com/mlazos
2025-09-12 06:38:32 +00:00
Boyuan Feng
601ae8e483 [CUDAGraph] add config to error on skipping cudagraph (#161862)
Many users want a config to force all CUDA ops to be captured by cudagraph; when that is not possible, pt2 should error.

This PR adds `torch._inductor.triton.cudagraph_or_error` for that (default: False). It also adds an environment variable `TORCHINDUCTOR_CUDAGRAPH_OR_ERROR` to control it.
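
For illustration, a hedged sketch of how the flag could be enabled, assuming it is exposed under `torch._inductor.config.triton` as the description suggests:

```python
import os

import torch

# Option 1: set the inductor config flag directly (attribute path is an
# assumption based on the PR description; default is False).
torch._inductor.config.triton.cudagraph_or_error = True

# Option 2: set the environment variable before compilation.
os.environ["TORCHINDUCTOR_CUDAGRAPH_OR_ERROR"] = "1"

@torch.compile(mode="reduce-overhead")
def f(x):
    return x + 1

# With the flag set, a graph that cannot be fully captured by cudagraph is
# expected to raise an error instead of silently falling back.
f(torch.ones(4, device="cuda"))
```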

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161862
Approved by: https://github.com/ezyang, https://github.com/mlazos
2025-09-04 15:52:39 +00:00
PyTorch MergeBot
f27985b7e7 Revert "[CUDAGraph] add config to error on skipping cudagraph (#161862)"
This reverts commit 204697f0e6.

Reverted https://github.com/pytorch/pytorch/pull/161862 on behalf of https://github.com/jeanschmidt due to Breaks internal tests, see D81522732 for more details ([comment](https://github.com/pytorch/pytorch/pull/161862#issuecomment-3249582583))
2025-09-03 14:50:44 +00:00
Boyuan Feng
204697f0e6 [CUDAGraph] add config to error on skipping cudagraph (#161862)
Many users want a config to force all CUDA ops to be captured by cudagraph; when that is not possible, pt2 should error.

This PR adds `torch._inductor.triton.cudagraph_or_error` for that (default: False). It also adds an environment variable `TORCHINDUCTOR_CUDAGRAPH_OR_ERROR` to control it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161862
Approved by: https://github.com/ezyang
2025-09-02 15:28:22 +00:00
Boyuan Feng
5f1010fbb3 [Graph Partition] Pass all OSS unit tests (#154667)
Graph partition leads to a 6.2% speedup on vision_maskrcnn and a 5.8% speedup on yolov3 [P1819700563](https://www.internalfb.com/phabricator/paste/view/P1819700563), a 39.5% speedup on speech_transformer inference [P1830602200](https://www.internalfb.com/phabricator/paste/view/P1830602200), and an 85% speedup on speech_transformer training [P1831115315](https://www.internalfb.com/phabricator/paste/view/P1831115315).

We ran the same diff on two days and both runs show a speedup on average.

[first TorchInductor Benchmark ci run](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Mon%2C%2021%20Jul%202025%2016%3A37%3A55%20GMT&stopTime=Mon%2C%2028%20Jul%202025%2016%3A37%3A55%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=bf/partition-turn-on&lCommit=75ef90fe89b82c967362a2d40fdf1af047202bc2&rBranch=main&rCommit=abcb24f4de11f8fedf2c2c9ff53b6092ef42306d)
<img width="1885" height="752" alt="image" src="https://github.com/user-attachments/assets/13bba9fc-5dbf-42ad-8558-d54f7e367b41" />

[second TorchInductorBenchmark ci run](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Wed%2C%2023%20Jul%202025%2016%3A38%3A27%20GMT&stopTime=Wed%2C%2030%20Jul%202025%2016%3A38%3A27%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=bf/partition-turn-on&lCommit=66de27e29338c26b1be94733049868cb0309ea52&rBranch=main&rCommit=70d2e9ba455c3c910f6f95b24171c8eee7bc00bf)
<img width="2513" height="1030" alt="image" src="https://github.com/user-attachments/assets/3a413dcb-2314-4292-919a-7ca181f9eeac" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154667
Approved by: https://github.com/eellison
2025-08-12 04:37:58 +00:00
PyTorch MergeBot
09381f5dac Revert "[Graph Partition] Pass all OSS unit tests (#154667)"
This reverts commit ca7315c171.

Reverted https://github.com/pytorch/pytorch/pull/154667 on behalf of https://github.com/clee2000 due to broke inductor/test_memory.py::TestOperatorReorderForPeakMemory::test_reorder_peak_memory_lpmf [GH job link](https://github.com/pytorch/pytorch/actions/runs/16885961204/job/47836769279) [HUD commit link](ca7315c171) note to self: bad TD ([comment](https://github.com/pytorch/pytorch/pull/154667#issuecomment-3176805477))
2025-08-11 20:34:27 +00:00
Boyuan Feng
ca7315c171 [Graph Partition] Pass all OSS unit tests (#154667)
Graph partition leads to a 6.2% speedup on vision_maskrcnn and a 5.8% speedup on yolov3 [P1819700563](https://www.internalfb.com/phabricator/paste/view/P1819700563), a 39.5% speedup on speech_transformer inference [P1830602200](https://www.internalfb.com/phabricator/paste/view/P1830602200), and an 85% speedup on speech_transformer training [P1831115315](https://www.internalfb.com/phabricator/paste/view/P1831115315).

We ran the same diff on two days and both runs show a speedup on average.

[first TorchInductor Benchmark ci run](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Mon%2C%2021%20Jul%202025%2016%3A37%3A55%20GMT&stopTime=Mon%2C%2028%20Jul%202025%2016%3A37%3A55%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=bf/partition-turn-on&lCommit=75ef90fe89b82c967362a2d40fdf1af047202bc2&rBranch=main&rCommit=abcb24f4de11f8fedf2c2c9ff53b6092ef42306d)
<img width="1885" height="752" alt="image" src="https://github.com/user-attachments/assets/13bba9fc-5dbf-42ad-8558-d54f7e367b41" />

[second TorchInductorBenchmark ci run](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Wed%2C%2023%20Jul%202025%2016%3A38%3A27%20GMT&stopTime=Wed%2C%2030%20Jul%202025%2016%3A38%3A27%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=bf/partition-turn-on&lCommit=66de27e29338c26b1be94733049868cb0309ea52&rBranch=main&rCommit=70d2e9ba455c3c910f6f95b24171c8eee7bc00bf)
<img width="2513" height="1030" alt="image" src="https://github.com/user-attachments/assets/3a413dcb-2314-4292-919a-7ca181f9eeac" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154667
Approved by: https://github.com/eellison
2025-08-11 16:25:12 +00:00
ghostspiders
af10f1f86c Fix requires_cuda to requires_cuda_and_triton (#160222)
Fixes #159399

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160222
Approved by: https://github.com/janeyx99
2025-08-10 07:05:52 +00:00
gaoyvfeng
50f23ff6f8 rename-HAS_CUDA-to-HAS_CUDA_AND_TRITON (#159883)
Fixes #159399
"Modified torch.testing._internal.inductor_utils and test/inductor"

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159883
Approved by: https://github.com/janeyx99
2025-08-08 15:44:52 +00:00
Markus Hoehnerbach
57f738b635 [inductor] move all cpu scalars using pinned memory for graph partition (#155360) (#158983)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158983
Approved by: https://github.com/eellison
ghstack dependencies: #158758
2025-08-07 17:07:26 +00:00
PyTorch MergeBot
1fad16aacb Revert "[inductor] move all cpu scalars using pinned memory for graph partition (#155360) (#158983)"
This reverts commit 444e2381d0.

Reverted https://github.com/pytorch/pytorch/pull/158983 on behalf of https://github.com/davidberard98 due to I need to revert #158462 (it causes device-side asserts), and this PR causes a merge conflict in the test file. Sorry about that! ([comment](https://github.com/pytorch/pytorch/pull/158758#issuecomment-3152490371))
2025-08-04 21:47:11 +00:00
Markus Hoehnerbach
444e2381d0 [inductor] move all cpu scalars using pinned memory for graph partition (#155360) (#158983)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158983
Approved by: https://github.com/eellison
ghstack dependencies: #158758
2025-08-04 21:42:05 +00:00
Boyuan Feng
6b9473469f [Graph Partition] add log for graph partition reasons and #partitions (#159425)
Previously, we logged `skipping cudagraphs due to [xxx reasons]` when there were cudagraph-unsafe ops. With graph partition, we split off these ops and cudagraph the remaining parts, but the log message is then also skipped.

In this PR, we add logs for graph partition reasons and the number of partitions to better understand the workload.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159425
Approved by: https://github.com/eellison
2025-07-31 04:21:06 +00:00
Xuehai Pan
17687eb792 [BE][4/6] fix typos in test/ (test/inductor/) (#157638)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157638
Approved by: https://github.com/yewentao256, https://github.com/jansel
2025-07-06 06:34:25 +00:00
Xuehai Pan
f5e6e52f25 [BE][PYFMT] migrate PYFMT for test/inductor/ to ruff format (#148186)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148186
Approved by: https://github.com/jansel
2025-06-24 11:12:11 +00:00
Boyuan Feng
1044934878 [CUDAGraph] add config cudagraph_capture_sizes (#156551)
Users may want CUDAGraph for certain sizes and fallback for other sizes.

As discussed in Issue #121968, we would like to use cudagraph for [batch size [1,2,3,...,16]](https://github.com/pytorch/pytorch/issues/121968#issuecomment-2259942345) and fallback for others.

Another use case is [vllm](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/cuda_piecewise_backend.py#L114-L119), where 67 batch sizes (i.e., [1,2,4,8,16,24,32,...,512]) are captured and all other sizes fallback.

This PR implements the feature with `torch._inductor.config.triton.cudagraph_capture_sizes`. When it is specified, we only capture cudagraphs for these shapes. When it is None (the default), we capture cudagraphs for all shapes.

Example:
```python
import torch

torch._inductor.config.triton.cudagraph_capture_sizes = [(2,3), (4,5), (6, 2), (7,3)]

def f(x):
    return x + 1

f = torch.compile(f, mode="reduce-overhead", dynamic=False)

def run(batch_size, seq_len, d):
    x = torch.randn((batch_size, seq_len, d), device="cuda")
    # Need to mark the dimension as dynamic. Automated-dynamic
    # may have some ux issues on matching `cudagraph_capture_sizes`
    # with the actual dynamic shapes, since there are specialization and
    # multiple dynamo graphs.
    torch._dynamo.mark_dynamic(x, 0)
    torch._dynamo.mark_dynamic(x, 1)
    for _ in range(3):
        f(x)

for i in range(2, 10):
    for j in range(2, 10):
        run(i, j, 8)

num_cudagraph = torch._inductor.cudagraph_trees.get_container(0).tree_manager.new_graph_id()
assert num_cudagraph.id == 4
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156551
Approved by: https://github.com/bobrenjc93
2025-06-24 05:14:49 +00:00
Boyuan Feng
f34ab1628b [Graph Partition] move cpu scalar tensor to gpu (#154464)
cudagraph does not support cpu tensors. In this PR, we update the graph by explicitly moving cpu tensors to gpu when profitable, relying on graph partition to split off this data copy, and cudagraphifying the remaining gpu ops.

This PR unblocked the graph partition + cudagraph on speech_transformer, leading to 39.5% speedup on inference [P1830602200](https://www.internalfb.com/phabricator/paste/view/P1830602200), 85% speedup on training [P1831115315](https://www.internalfb.com/phabricator/paste/view/P1831115315).

Close: #119241

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154464
Approved by: https://github.com/eellison, https://github.com/mlazos
2025-06-11 10:22:45 +00:00
PyTorch MergeBot
2c1a93a0ae Revert "[Graph Partition] move cpu scalar tensor to gpu (#154464)"
This reverts commit c1f531f0b0.

Reverted https://github.com/pytorch/pytorch/pull/154464 on behalf of https://github.com/clee2000 due to some of the newly added tests are failing internally, along with some other tests, D75913292 ([comment](https://github.com/pytorch/pytorch/pull/154464#issuecomment-2957201054))
2025-06-09 22:43:20 +00:00
Boyuan Feng
c1f531f0b0 [Graph Partition] move cpu scalar tensor to gpu (#154464)
cudagraph does not support cpu tensors. In this PR, we update the graph by explicitly moving cpu tensors to gpu when profitable, relying on graph partition to split off this data copy, and cudagraphifying the remaining gpu ops.

This PR unblocked the graph partition + cudagraph on speech_transformer, leading to 39.5% speedup on inference [P1830602200](https://www.internalfb.com/phabricator/paste/view/P1830602200), 85% speedup on training [P1831115315](https://www.internalfb.com/phabricator/paste/view/P1831115315).

Close: #119241

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154464
Approved by: https://github.com/eellison
2025-06-07 06:59:39 +00:00
Boyuan Feng
d969e2ec33 [CUDAGraph Trees] support memory allocation on side stream (#152472)
I tried `beginAllocateToPool` instead of `_cuda_beginAllocateCurrentStreamToPool` and the error in #151199 no longer happens.

However, this approach is unsafe for multithreading. When multiple run_eager calls happen concurrently, we expect the memory allocations to go to different mem_pools. Since beginAllocateToPool does not check the stream, these allocations may end up in the same mem_pool.

So, I use `_cuda_beginAllocateCurrentThreadToPool` to direct all memory allocations on the same thread to a given mem_pool. In particular, `_cuda_beginAllocateCurrentThreadToPool` records the launching thread id and checks at runtime whether the current thread id matches it.
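
For illustration, a minimal sketch of the thread-guard idea (hypothetical class, not the actual allocator API):

```python
import threading

class ThreadScopedPoolGuard:
    """Route allocations to a mem_pool only from the thread that created the guard."""

    def __init__(self) -> None:
        # Record the launching thread id, mirroring the behavior described above.
        self.launching_tid = threading.get_ident()

    def should_route_to_pool(self) -> bool:
        # At allocation time, only the launching thread uses this mem_pool.
        return threading.get_ident() == self.launching_tid
```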

Fixes #151199

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152472
Approved by: https://github.com/eellison, https://github.com/ngimel
2025-05-02 04:26:35 +00:00
PyTorch MergeBot
56039b5778 Revert "[CUDAGraph Trees] support memory allocation on side stream (#152472)"
This reverts commit c620763ec2.

Reverted https://github.com/pytorch/pytorch/pull/152472 on behalf of https://github.com/BoyuanFeng due to should use tid instead pid ([comment](https://github.com/pytorch/pytorch/pull/152472#issuecomment-2843491656))
2025-04-30 22:18:10 +00:00
Boyuan Feng
c620763ec2 [CUDAGraph Trees] support memory allocation on side stream (#152472)
I tried `beginAllocateToPool` instead of `_cuda_beginAllocateCurrentStreamToPool` and the error in #151199 no longer happens.

However, this approach is unsafe for multithreading. When multiple run_eager calls happen concurrently, we expect the memory allocations to go to different mem_pools. Since beginAllocateToPool does not check the stream, these allocations may end up in the same mem_pool.

So, I use `_cuda_beginAllocateCurrentThreadToPool` to direct all memory allocations on the same thread to a given mem_pool. In particular, `_cuda_beginAllocateCurrentThreadToPool` records the launching thread id and checks at runtime whether the current thread id matches it.

Fixes #151199

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152472
Approved by: https://github.com/eellison
2025-04-30 17:45:07 +00:00
Brian Hirsh
4a63cab624 [cudagraphs] Fix issue in collecting static_input_idxs (#152287)
related to https://github.com/pytorch/pytorch/issues/152275

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152287
Approved by: https://github.com/bdhirsh, https://github.com/eellison

Co-authored-by: Brian Hirsh <hirsheybar@fb.com>
2025-04-30 03:24:05 +00:00
Boyuan Feng
797768cd90 [Graph Partition] reorder for minimal number of partitions (#151968)
This PR adds an optimal reordering for minimizing #partitions.

## Optimal reordering for minimizing #partitions

A BFS can minimize #partitions (ignoring peak memory for now):
1. For each node, compute node_to_indegree: dict[node, int].
2. Maintain 2 queues: cudagraphable_nodes, and non_cudagraphable_nodes. Iterate through all nodes and add nodes to one of these 2 queues if node_to_indegree[node] == 0.
3. While non_cudagraphable_nodes is not empty: Pop 1 node, schedule it, update the indegree of all its successors, and add its successor nodes to one of the queues if node_to_indegree[successor] == 0.
4. While cudagraphable_nodes is not empty: Pop 1 node, schedule it, update the indegree of all its successors, and add its successor nodes to one of the queues if node_to_indegree[successor] == 0.
5. Repeat steps 3 & 4 until all nodes have been scheduled.

We call this strategy `reorder_for_minimizing_partition`.
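
For illustration, a minimal Python sketch of the two-queue BFS described above (the node/edge structures are illustrative, not the actual scheduler types):

```python
from collections import deque

def reorder_for_minimizing_partition(nodes, successors, is_cudagraphable):
    # nodes: hashable node ids; successors: dict[node, list[node]];
    # is_cudagraphable: dict[node, bool]. All names here are illustrative.
    indegree = {n: 0 for n in nodes}
    for n in nodes:
        for s in successors.get(n, []):
            indegree[s] += 1

    cudagraphable_q, non_cudagraphable_q = deque(), deque()

    def enqueue(n):
        (cudagraphable_q if is_cudagraphable[n] else non_cudagraphable_q).append(n)

    for n in nodes:
        if indegree[n] == 0:
            enqueue(n)

    order = []

    def drain(queue):
        while queue:
            n = queue.popleft()
            order.append(n)
            for s in successors.get(n, []):
                indegree[s] -= 1
                if indegree[s] == 0:
                    enqueue(s)

    # Alternate: drain all ready non-cudagraphable nodes, then all ready
    # cudagraphable nodes, until every node has been scheduled.
    while len(order) < len(nodes):
        drain(non_cudagraphable_q)
        drain(cudagraphable_q)
    return order
```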

**Q: Why is this optimal?**

Suppose this is not optimal; then we have a counterexample with 2 non_cudagraphable regions:

```
[non_cudagrable1, cudagraphable2, non_cudagraphable3]
```

where we can reorder to only 1 non_cudagraphable region:

```
[non_cudagrable1, non_cudagraphable3, cudagraphable2]
```

This reordering means non_cudagraphable3 does not depend on cudagraphable2. So after we schedule non_cudagraphable1, both non_cudagraphable3 and cudagraphable2 have an in-degree of 0. If that is the case, step 3 would have already scheduled non_cudagraphable3 before cudagraphable2, so the counterexample cannot exist.

This shows we cannot find such a counterexample, so the BFS is optimal for minimizing #partitions.

## Minimize peak memory

`reorder_for_peak_memory` currently uses topological_sort_dfs, topological_sort_lpmf, and topological_sort_bfs, where the latter 2 are BFS-based. ILP brings small benefits and can hardly scale to more than 100 nodes, according to @xuanzhang816. So ILP is not used for peak-memory reordering in inductor.

Heuristics strategy:
- Conduct reorder_for_peak_memory as the default order
- Conduct reorder_for_minimal_partitions and get the result as list[tuple[partition, bool]], where partition is a list[BaseSchedulerNode] and the bool indicates whether the partition is cudagraphable.
- If the reorder increases peak memory too much, we use the default order (see the sketch after this list).
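
For illustration, a minimal sketch of this heuristic (helper functions and the threshold are assumptions, not the actual inductor code):

```python
def choose_order(
    nodes,
    estimate_peak_memory,              # hypothetical: order -> estimated peak bytes
    reorder_for_peak_memory,           # hypothetical: nodes -> default order
    reorder_for_minimizing_partition,  # hypothetical: order -> partition-minimizing order
    threshold=1.1,                     # assumed tolerance, not from the PR
):
    default_order = reorder_for_peak_memory(nodes)
    candidate = reorder_for_minimizing_partition(default_order)
    # Fall back to the default order if minimizing #partitions inflates the
    # estimated peak memory beyond the tolerated threshold.
    if estimate_peak_memory(candidate) > threshold * estimate_peak_memory(default_order):
        return default_order
    return candidate
```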

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151968
Approved by: https://github.com/eellison
2025-04-29 17:17:16 +00:00
PyTorch MergeBot
a6d19fcfac Revert "[cudagraphs] Fix issue in collecting static_input_idxs (#152287)"
This reverts commit 75a564608a.

Reverted https://github.com/pytorch/pytorch/pull/152287 on behalf of https://github.com/wdvr due to causing ao failures - discussed with author ([comment](https://github.com/pytorch/pytorch/pull/152287#issuecomment-2837686127))
2025-04-29 06:57:06 +00:00
Animesh Jain
75a564608a [cudagraphs] Fix issue in collecting static_input_idxs (#152287)
related to https://github.com/pytorch/pytorch/issues/152275

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152287
Approved by: https://github.com/bdhirsh, https://github.com/eellison
2025-04-28 23:07:52 +00:00
Boyuan Feng
99b6c426a9 [Graph Partition] fix extra reference in runner.partitions to cudagraphify functions (#152066)
When CompiledFxGraph is deallocated, its cudagraphified fn (i.e., `current_callable`) is expected to also be deallocated.
Without graph partition, this is true since the cudagraphified fn is only referenced by compiled_fx_graph.current_callable.

However, with graph partition, runner.partitions hold the cudagraphified fns while compiled_fx_graph.current_callable holds runner.call. Thus the cudagraphified fns may not be deallocated when CompiledFxGraph is deallocated. This leads to errors in several unit tests (e.g., test_unaligned_static_input_no_cudagraphs and test_unaligned_static_input_non_trees).

In this PR, we also clean up runner.partitions when CompiledFxGraph is deallocated. This fixes the issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152066
Approved by: https://github.com/eellison
2025-04-28 20:38:26 +00:00
eellison
a5f2fd1017 Unskip index_put in cudagraphs (#152186)
The repro from the original skip in https://github.com/pytorch/pytorch/pull/105439 no longer fails, so we unskip it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152186
Approved by: https://github.com/Skylion007
2025-04-25 18:15:49 +00:00
PaulZhang12
3ed5f1fb77 [CUDA][cuBLAS] Aten GEMM overload for FP32 output from FP16/BF16 inputs (#150812)
Enable FP32 output from FP16/BF16 GEMMs in aten with cuBLAS. Accumulation for these GEMMs is generally already done in FP32. This adds the functionality to the following aten operators:
* mm
* bmm
* addmm
* baddbmm

Follow up of customer issue: https://github.com/pytorch/pytorch/issues/146241#issuecomment-2781889390

Differential Revision: [D73126191](https://our.internmc.facebook.com/intern/diff/D73126191)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150812
Approved by: https://github.com/ngimel, https://github.com/eqy
2025-04-18 01:53:26 +00:00
Boyuan Feng
5b5399bfcd [graph partition] reorder to reduce #partitions for simple dependencies (#150814)
This PR reduces #graph partitions by reordering nodes when the `should_partition` nodes have simple dependencies. Specifically, for `should_partition` nodes:
    a. If a node has no dependency or only depends on graph inputs: move it to the front. The use case is when we move symints to a cuda tensor for the PaddedTensor subclass.
    b. If the only user of a node is OutputNode: move it to the end.

#### Example

The following example shows a padded tensor subclass use case where we copy a symint to a cuda tensor (aka mask) in the middle of the function. Reordering still generates 1 cudagraph by moving the mask creation to the front.

```python
import torch

torch._inductor.config.graph_partition = True

# Two reasons for this:
# 1. We want to reuse the same mask for many masked_fill calls
# 2. Prevent inductor from fusing this op into other ops (e.g. masked_fill)
#    so we can still reorder in scheduler
@torch.library.custom_op("mylib::create_mask", mutates_args=(), tags=(torch._C.Tag.cudagraph_unsafe,))
def create_mask(padded_size: int, original_size: int, device: torch.device) -> torch.Tensor:
    mask = torch.zeros((padded_size,), dtype=torch.bool, device=device)
    mask[original_size:] = True
    return mask

@create_mask.register_fake
def _(padded_size, original_size, device):
    return torch.empty((padded_size,), dtype=torch.bool, device=device)

def f(padded_tensor, original_tensor, weight):
    original_size = original_tensor.size()[0]
    padded_size = padded_tensor.size()[0]

    # element wise op so we don't care padding value
    padded_tensor = padded_tensor + 1
    padded_tensor = torch.nn.functional.relu(padded_tensor)

    # dot product requires padding with 0
    dot_res = padded_tensor.dot(weight)
    padded_tensor += dot_res

    # min requires padding with inf, so we create mask now
    mask = create_mask(padded_size, original_size, padded_tensor.device)
    min_res = torch.min(
        torch.ops.aten.masked_fill(padded_tensor, mask, float("inf"))
    )

    # max requires padding with inf. we can reuse previous mask
    max_res = torch.max(
        torch.ops.aten.masked_fill(padded_tensor, mask, -float("inf"))
    )

    return min_res+max_res+padded_tensor

compiled_f = torch.compile(f, mode="reduce-overhead")

def run(padded_size, original_size):
    padded_tensor = torch.randn(padded_size, device="cuda")
    padded_tensor[original_size:] = 0
    original_tensor = torch.randn(original_size, device="meta")

    weight = torch.randn(padded_size, device="cuda")
    eager_out = f(padded_tensor, original_tensor, weight)
    compiled_out = compiled_f(padded_tensor, original_tensor, weight)
    assert torch.allclose(eager_out[0], compiled_out[0])
    assert torch.allclose(eager_out[1], compiled_out[1])

# new cudagraph
run(8, 4)

# new cudagraph due to recompile
run(8, 6)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150814
Approved by: https://github.com/eellison
2025-04-16 20:49:20 +00:00
Boyuan Feng
c1470d4dc4 [graph partition] support graphsafe_run_with_rng_state (#150958)
Prior to this PR, `rng_state` is in `V.graph.graph_inputs` but not in the read_writes of any IRNode. As a result, it is not identified as a partition input:
```python
def partition_0(args):
    primals_2, primals_1 = args
    ...
    buf0 = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype=torch.float32, device=device(type='cuda', index=1), pin_memory=False, rng_state=fwd_rng_state_0)
    # <----- access fwd_rng_state_0 but it's not an input
    ...

def call(self, args):
    primals_1, primals_2, fwd_rng_state_0 = args
    ...
    partition0_args = [primals_2, primals_1]
    (buf2, primals_2, primals_1) = self.partitions[0](partition0_args)
     # <---- fwd_rng_state_0 is graph_inputs but is not passed to partitions[0]
     ...
```

This PR fixes this issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150958
Approved by: https://github.com/eellison
2025-04-12 03:17:08 +00:00
Boyuan Feng
3f54b14c75 [CUDAGraph] support meta tensor (#150478)
Previously, cudagraph was skipped if the graph contained any meta tensor. However, we should not skip since meta tensors do not involve actual computation. This PR fixes the issue.

### Example

```python
import torch

def foobar(x, y):
    return x * 2, y * 3

foo_c = torch.compile(mode="reduce-overhead")(foobar)
t = torch.empty((1, 16, 128, 128), device="meta")
y = torch.rand([64], device="cuda")

eager_out = foobar(t, y)

for _ in range(3):
    compiled_out = foo_c(t, y)
```

Prior to this PR, above code leads to
```
skipping cudagraphs due to multiple devices: device(type='cuda', index=0), device(type='meta')
```

With this PR, we don't skip.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150478
Approved by: https://github.com/eellison
2025-04-02 07:21:50 +00:00
Alexander Grund
350a479146 Fix test failures on non-x86 Linux (#148445)
The cpp contexts are only supported on x86 Linux.
The tests requiring them are skipped on non-Linux but not when the architecture is not x86.
In most places the check is for ARM64, which is not enough; a check for x86 is required instead.

Fix the test decorators and factor out a common one in test_cuda.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148445
Approved by: https://github.com/eellison
2025-03-28 15:27:44 +00:00
Boyuan Feng
c830d750e6 [graph partition] support splitting on custom ops (#149782)
This PR adds support for graph partition on custom ops. Land after #149458.

### API
This PR provides a new API to register/unregister custom ops for graph partition.

```python
def register_custom_op_support_cudagraph(
    operator: torch._library.custom_ops.CustomOpDef,
    is_cudagraphable: bool,
) -> None
```

Example usage:

```python
from torch._inductor.utils import register_custom_op_support_cudagraph

@torch.library.custom_op("mylib::movement", mutates_args=())
def movement(pic: torch.Tensor) -> torch.Tensor:
    img = pic.cpu()
    cropped_img = (img + 1) * 2
    return cropped_img.cuda() / 255.0

@movement.register_fake
def _(pic):
    return torch.empty_like(pic)

register_custom_op_support_cudagraph(movement, is_cudagraphable=False)
```

### Example
In this example, 1 torch-compiled region has 3 cudagraphs after splitting on 2 custom ops.

![image](https://github.com/user-attachments/assets/6d07355b-6690-4cde-89ef-e4aff6b0079c)

Code to repro:
```python
import torch
from torch._inductor.utils import register_custom_op_support_cudagraph

torch._inductor.config.graph_partition = True

@torch.library.custom_op("mylib::movement", mutates_args=())
def movement(pic: torch.Tensor) -> torch.Tensor:
    img = pic.cpu()
    cropped_img = (img + 1)*2
    return cropped_img.cuda() / 255.

@movement.register_fake
def _(pic):
    return torch.empty_like(pic)

@torch.library.custom_op("mylib::modify", mutates_args=())
def modify(pic: torch.Tensor) -> torch.Tensor:
    pic1 = pic + 1
    pic1_cpu = (pic1.cpu() + 1) * 2
    return pic1_cpu.cuda() + pic

@modify.register_fake
def _(pic):
    return torch.empty_like(pic)

@torch.library.custom_op("mylib::transform", mutates_args=())
def transform(pic: torch.Tensor) -> torch.Tensor:
    return (pic + 1) * 2

@transform.register_fake
def _(pic):
    return torch.empty_like(pic)

register_custom_op_support_cudagraph(movement, is_cudagraphable=False)
register_custom_op_support_cudagraph(modify, is_cudagraphable=False)

img = torch.randn(3, 64, 64, device="cuda")

def f(img):
    x = (img + 10) * 2
    y = movement(x)
    z = y + 1
    u = transform(z)
    v = 2*u + 1
    out = modify(v)
    return out + 1

compiled_f = torch.compile(f, mode="reduce-overhead", fullgraph=True)

eager_out = f(img)

for _ in range(3):
    compiled_out = compiled_f(img)
    assert torch.allclose(eager_out, compiled_out)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149782
Approved by: https://github.com/zou3519
2025-03-27 16:23:07 +00:00
Boyuan Feng
039ebdc192 [Graph Partition] Support symbol inputs (#149458)
This PR supports symbol inputs to graph partition functions. Before this PR, we relied on `node.read_writes` to get partition inputs. However, this does not cover symbol inputs.

In this PR, for each graph partition, we collect all symbol inputs which must be in scope to successfully perform codegen (see the sketch after this list), including:
- free symbols used in partition nodes.
- free symbols in partition input/node shapes, strides, and offsets. This is needed for recording cudagraphs for tensors with dynamic shapes.
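
For illustration, a self-contained sketch of the symbol collection described above, with plain sympy expressions standing in for node shapes/strides/offsets (not the actual inductor code):

```python
import sympy

def collect_partition_symbol_inputs(node_free_symbols, layouts):
    """Collect the symbols a partition needs in scope.

    node_free_symbols: iterable of sets of sympy.Symbol used inside partition nodes.
    layouts: iterable of (size, stride, offset) tuples of sympy expressions for
             partition inputs/nodes. Both arguments are illustrative stand-ins.
    """
    symbols = set()
    for syms in node_free_symbols:
        symbols |= set(syms)
    for size, stride, offset in layouts:
        for expr in (*size, *stride, offset):
            symbols |= sympy.sympify(expr).free_symbols
    return symbols

# Example: a partition input of shape (s2, s3) with stride (s3, 1) and offset 0.
s2, s3 = sympy.symbols("s2 s3")
print(collect_partition_symbol_inputs([], [((s2, s3), (s3, 1), 0)]))  # {s2, s3}
```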

### Note1: MutationLayout
In this example, node.layout is MutationLayoutSHOULDREMOVE. The symint from index `n` does not appear in the size, offset, or strides of node.layout; it appears in node.layout.target. So we need extra handling for it.

```python
x = torch.zeros(7, device="cuda")

def fn(n, a):
    a[n] = -1
    return a

opt_fn = torch.compile(fn, fullgraph=True)

for n in range(2, x.shape[0]):
    opt_fn(n, x)
```

### Note2: Composability with Padded Tensor Subclass

Without graph partition, the PaddedTensor subclass lifts outer shapes to input arguments (i.e., arg0_1 for s0, arg1_1 for s1) but does not lift inner shapes (i.e., s2 and s3). Since the cudagraph cache relies on integer inputs, it will cache on outer shapes and ignore inner shapes, which is bad.

```
def call(args):
    arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1 = args
    args.clear()
    s0 = arg0_1
    s1 = arg1_1
    arg2_1_size = arg2_1.size()
    s2 = arg2_1_size[0]
    s3 = arg2_1_size[1]
    assert_size_stride(arg2_1, (s2, s3), (s3, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((s2, s3), (s3, 1), torch.float32)
        # Topologically Sorted Source Nodes: [x1, mul], Original ATen: [aten.add, aten.mul]
        triton_poi_fused_add_mul_0_xnumel = s2*s3
        stream0 = get_raw_stream(0)
        triton_poi_fused_add_mul_0.run(arg2_1, buf0, triton_poi_fused_add_mul_0_xnumel, stream=stream0)
        del arg2_1
    return (buf0, s0, s1, s1, )
```

With graph partition, the partition function only includes tensors and inner shapes as inputs, to make sure the cudagraph caching is correct. Full comparison: [code](https://www.internalfb.com/intern/diffing/?paste_number=1761674743)
```python
   def call(self, args):
        arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1 = args
        args.clear()
        s0 = arg0_1
        s1 = arg1_1
        arg2_1_size = arg2_1.size()
        s2 = arg2_1_size[0]
        s3 = arg2_1_size[1]
        assert_size_stride(arg2_1, (s2, s3), (s3, 1))
        partition0_args = [arg2_1, s2, s3]
        del arg2_1
        (buf0,) = self.partitions[0](partition0_args)
        del partition0_args
        return (buf0, s0, s1, s1, )
```

The number of cudagraphs is validated below (also added to the test):
```python
import torch

from padded_tensor import PaddedTensor

# Turning off graph_partition leads to
# torch._inductor.cudagraph_trees.get_container(0).tree_manager.new_graph_id().id=6
# at the end, which is wrong.
# torch._inductor.config.graph_partition = False

# Turning on graph_partition leads to
# torch._inductor.cudagraph_trees.get_container(0).tree_manager.new_graph_id().id=4
# at the end, which is correct.
torch._inductor.config.graph_partition = True

def f(x):
    x1 = x + 1
    return x1 * 2

compiled_f = torch.compile(f, mode="reduce-overhead")

def run(shape):
    x = torch.randn(*shape, device="cuda")
    pad_x = PaddedTensor.from_tensor(x, multipliers={0:4, 1:4})
    assert hasattr(pad_x, "multipliers"), breakpoint()
    eager_out = f(pad_x)

    for _ in range(3):
        compiled_out = compiled_f(pad_x)
    compiled_out = compiled_f(pad_x)

    assert eager_out.shape == compiled_out.shape
    assert eager_out.tensor.shape == compiled_out.tensor.shape
    assert torch.allclose(eager_out.tensor, compiled_out.tensor)

# static shape. record a NEW cudagraph. 1 cudagraph in total now.
run((2,3))
# outer shape is dynamic, leading to a new dynamo graph
# this new dynamo graph forces a NEW cudagraph. 2 cudagraphs in total now
run((3,4))
# outer shape changed but inner shape does not change
# so NO new cudagraph is recorded
run((2,2))
# inner shape is dynamic now, leading to a new dynamo graph
# this new dynamo graph forces a NEW cudagraph. 3 cudagraphs in total now
run((5,6))
# does NOT record a new cudagraph
run((7,8))
# record a NEW cudagraph. 4 cudagraphs in total now
run((10,11))

assert torch._inductor.cudagraph_trees.get_container(0).tree_manager.new_graph_id().id == 4
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149458
Approved by: https://github.com/eellison
2025-03-26 17:21:30 +00:00
James Wu
fe954cdcbf Use correct boxed_forward_device_index when running CompiledFxGraph.post_compile (#148130)
This PR threads through the correct boxed_forward_device_index from graph_kwargs to CompiledFXGraph.post_compile. This allows us to correctly update BoxedDeviceIndex from cache hits.

We don't actually need to save `boxed_forward_device_index` in CompiledFXGraph because its value is in the cache key, so it always matches to the ambient one anyway. On forward with cudagraphs enabled, derive `boxed_forward_device_index`'s value from `device_idxs`.

Testing:

```
python benchmarks/dynamo/cachebench.py --mode training --benchmark torchbench --model BERT_pytorch --device cuda --repeat 1 --dynamic --output="dynamic.json"
```

This now cache-hits properly on FXGraphCache. AOTAutogradCache has a guard failure; we will look into that as a follow-up.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148130
Approved by: https://github.com/eellison
2025-03-23 02:57:58 +00:00
eqy
6048d88afe [ARM64][CUDA] skip string pattern matching in test_workspace_allocation_error (#149236)
`unwind()` on ARM64 seems to elide the strings of interest

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149236
Approved by: https://github.com/malfet, https://github.com/eellison, https://github.com/BoyuanFeng
2025-03-17 00:30:43 +00:00
Boyuan Feng
3e605fe46d [CUDAGraph] Graph Partition (#147648)
This PR implements cudagraph partition, following the previous PR on inductor graph partition (#147038). Since there are many ops that cudagraph cannot support, this PR focuses on `cpu ops`; more partition rules will be added in the next PR.

## Example
```python
import torch

torch._inductor.config.graph_partition = True

def f(x, y):
    x1 = x + 1
    y1 = y + 1
    y_cpu = y1.cpu() + 1
    z = x @ y
    return x1 + y1 + z + y_cpu.cuda()

x, y = [torch.ones(2, 2, device="cuda") for _ in range(2)]
x_cloned, y_cloned = [tmp.clone() for tmp in [x,y]]
eager_out = f(x, y)

f_compiled = torch.compile(f, mode="reduce-overhead")

for _ in range(5):
    compiled_out = f_compiled(x_cloned, y_cloned)
    assert torch.allclose(eager_out, compiled_out)
```

Without graph partition, we skip cudagraph:
```
skipping cudagraphs due to skipping cudagraphs due to cpu device (device_put). Found from :
   File "/home/boyuan/playground/cudagraph/graph_partition/graph_partition.py", line 9, in f
    y_cpu = y1.cpu() + 1 # 3
```

With graph partition, we can see two cudagraphify calls under the same torch-compiled region:
![image](https://github.com/user-attachments/assets/4e22d428-2687-433d-b92a-0814a2201b25)

## Design

PR #147038 splits the `def call(args)` function into multiple `def partition_id(args)` functions. In this PR, we use `recursively_apply_fns()` to wrap each `partition_id()` function with `cudagraphify`. One major design point is that `cudagraphify` takes metadata such as static_input_idxs, and we need to provide such metadata for each graph partition. However, we previously only had such metadata for the original graph, not for the graph partitions.

The [idea](https://github.com/pytorch/pytorch/pull/147038#discussion_r1964124800) is:
- compute a mapping from the partition metadata (e.g., input/output idx) to the graph metadata, stored in `GraphPartitionMap`.
- during post_compile, get the `CudagraphMetadata` for each partition based on the graph-level metadata and `GraphPartitionMap`, via `get_partition_cudagraph_metadata()`.
- finally, in `cudagraph_partition_pos_compile`, we compute the `CudagraphMetadata` and apply cudagraphify for each graph via `recursively_apply_fns`.

#### Q: How does it work with codecache?

While we have multiple graph partitions, we still have 1 file and 1 `call` function for 1 dynamo graph. The major difference is that we additionally need to load a `recursively_apply_fns()` for graph partition. We also add `partition_maps: Optional[list[GraphPartitionMap]]` to `CompiledFxGraph` so it is serialized and can be deserialized later.

## Edge Case 1
PyTorch has an assumption on input/output orders. For example, backward inputs take saved tensors first and then tangents. In graph partition, we respect such orders via `graph_partition_signature_reorder`.

## Edge Case 2
Cudagraphifying the `call` function gives 2 cudagraph-managed tensors, `buf0` and `primals_1`. However, cudagraphifying `partition_0` gives only 1 cudagraph-managed tensor, `buf0`. This leads to a semantic difference between cudagraph with and without graph partition. [full code comparison](https://www.internalfb.com/intern/diffing/?paste_number=1747654420)

![image](https://github.com/user-attachments/assets/03d08ce0-f1d1-4d1d-8432-805a07e1dd40)

To achieve the same semantics, we return an input tensor as an output if it is not freed in a graph partition. This allows more cudagraph-managed tensors and is important for handling saved tensors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147648
Approved by: https://github.com/eellison
2025-03-13 16:00:21 +00:00
eellison
a7fe685be8 Add cpp wrapper skip to cudagraph logs (#148700)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148700
Approved by: https://github.com/jbschlosser
2025-03-07 01:02:40 +00:00
cyy
b7832f0339 Enable ASAN in CUDA tests (#147812)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147812
Approved by: https://github.com/janeyx99
2025-03-04 02:50:39 +00:00
eellison
481a57bc37 Support torch.compile rng selective activation checkpointing with cudagraph (#146878)
TODO:
- [x]  Add handling for when forward is invoked multiple times without invoking backward, so that the fwd/backward states are out of sync
- [x] Update rng state initialization to take from correct device
- [x]  Tests
- [x] handling of retain_graph
- [x] respect fallback random

Fix for https://github.com/pytorch/pytorch/issues/130123.

Updates the aot_eager and cudagraph compilation of `run_and_save_rng_state` to use the new mechanism added by https://github.com/pytorch/pytorch/pull/114068 for CUDAGraph safe rng states.

We have a pair of rng states for the forward and backward respectively. In both the forward and backward, the rng op is run with `graphsafe_run_with_rng_state`, which takes in an RNG state and hooks it onto the current RNG generator before running the operator. The rng states for fwd/backward are initialized with the same value. We ensure that for any given run of the forward, the corresponding backward run will have the same rng states for the op as were observed in the forward.

```
 ===== Forward graph 1 =====
 /data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", fwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = fwd_rng_state_0);  fwd_rng_state_0 = None
        ...

 ===== Backward graph 1 =====
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", tangents_1: "f32[4, 4][4, 1]cuda:0", bwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = bwd_rng_state_0);  bwd_rng_state_0 = None
```

There is some extra complication when a user either calls backward with retain_graph, or calls the backward in a different order than they called the forward. If a user has states fwd_rng_state0, bwd_rng_state0 and calls:
- fwd0: fwd_rng_state0 -> fwd_rng_state1
- fwd1: fwd_rng_state1 -> fwd_rng_state2
- bwd1
- bwd0

Then naively, when bwd1 is invoked, the bwd rng states would not be equal to the states that were observed in fwd1. I added handling of this in the aot runtime wrappers to detect pending backward invocations and the current position of the bwd rng states, and to update them when necessary.

Other notes:

Because nodes which appear later in the forward appear earlier in the backward, we need a separate rng state for each operator. If we reused the rng state across ops, the forward and backward would be run with different rng states, i.e., not applied in the same order.

Questions for reviewers:

This does change numerics, because the rng of the op is now taken from the input rng state instead of whatever the rng would be midway through running the graph. Technically, we only need this for cuda graphs, but I'd prefer not to have an rng divergence just for cudagraph. I am making it respect `fallback_random`.

Edit: decided to apply to non cudagraphs as well, so long as fallback_random is not set

I'm initializing the rng states by cloning the current state. If you had something like 5 different rands in the model with the same shape, they'd all get the same value. This doesn't seem great. I could use some other initialization scheme, like taking the seed from the graph position, etc. Not sure. Let me know your thoughts.

Edit: updated to be taken from randint()

Update: initializing rng states from torch.randint().

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146878
Approved by: https://github.com/anijain2305, https://github.com/bdhirsh
2025-02-28 00:47:03 +00:00
PyTorch MergeBot
17358ce778 Revert "Support torch.compile rng selective activation checkpointing with cudagraph (#146878)"
This reverts commit ad0c879e22.

Reverted https://github.com/pytorch/pytorch/pull/146878 on behalf of https://github.com/wdvr due to lint failure ([comment](https://github.com/pytorch/pytorch/pull/146878#issuecomment-2686767956))
2025-02-27 03:36:16 +00:00