Previous work #158352 delivered a CUDAGraph memory footprint reduction with no replay-time impact, but capture time regressed (up to 20× slower) due to repeated full-graph traversals. See the previous benchmark results [here](https://github.com/pytorch/pytorch/pull/158352#issuecomment-3215947565).

This PR removes the capture/replay overhead while preserving the memory savings:

1. **Terminals as free markers.** We stop inserting empty nodes and instead record the current stream terminals as free markers. This avoids mutating the user’s graph and keeps semantics unchanged.
2. **Incremental, cached reachability.** We add a **per-graph reuse context** that caches reverse-traversal state:
   * `graph_reuse_context[graph].visited[stream]` tracks nodes already seen from that stream’s terminal frontier.
   * On each allocation during capture, we resume traversal from the latest terminals and only visit unseen nodes.
   * A block is freed when all its recorded markers are in the visited set of its allocation stream, i.e., all markers are proven predecessors of future work (see the sketch after this description).

See [the performance results here](https://docs.google.com/spreadsheets/d/e/2PACX-1vRPvdd9Xa8W87ixbiA0da_qvOhrUAjUpFz0G-_j-MsDnoeRyhEa4_ut_W3rqcg1VVZVFJ-gucwov-3b/pubhtml?gid=1468302443&single=true). We sweep synthetic multi-stream CUDA Graphs built by `capture_benchmark.py` (same as before, we generate a random interleaving of alloc/free/join events with given probabilities; see the [gist here](https://gist.github.com/eee4017/e2092d215b1d4bd46534148939af39e3)) and compare median capture/replay times and memory. On an NVIDIA H100 PCIe across 24 configs, the optimization preserves the reserved-memory reduction at ~24–98%, leaves allocated memory unchanged, and brings capture time back to baseline (0.96–1.04× vs. baseline) with replay time unchanged (0.97–1.11×).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162186
Approved by: https://github.com/eqy, https://github.com/ngimel
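For intuition, here is a minimal, hypothetical Python sketch of the freeing rule in item 2 above. The actual change lives in the C++ CUDA caching allocator; `Node`, `Block`, `GraphReuseContext`, `record_free`, and `reusable_blocks` are illustrative names, not the real APIs.

```python
class Node:
    """A node captured in the CUDA graph; `deps` are its predecessor nodes."""
    def __init__(self, deps=()):
        self.deps = list(deps)


class Block:
    """A memory block freed during capture, remembered for later reuse."""
    def __init__(self, alloc_stream):
        self.alloc_stream = alloc_stream
        self.markers = set()  # stream terminals recorded when the block was freed


class GraphReuseContext:
    """Per-graph cache of reverse-traversal state: one visited set per stream."""

    def __init__(self):
        self.visited = {}   # stream id -> nodes proven to precede that stream's future work
        self.pending = []   # blocks freed during capture, not yet safe to reuse

    def record_free(self, block, per_stream_terminals):
        # "Terminals as free markers": instead of inserting empty nodes, record
        # the current terminal node(s) of each stream that used the block.
        for terminals in per_stream_terminals.values():
            block.markers.update(terminals)
        self.pending.append(block)

    def _extend_visited(self, stream, terminals):
        # Incremental reverse traversal: resume from the latest terminals and
        # stop at nodes already seen during a previous allocation.
        visited = self.visited.setdefault(stream, set())
        stack = [n for n in terminals if n not in visited]
        while stack:
            node = stack.pop()
            if node in visited:
                continue
            visited.add(node)
            stack.extend(node.deps)
        return visited

    def reusable_blocks(self, alloc_stream, terminals):
        # Called on each allocation during capture: a block becomes reusable once
        # all of its markers lie in the allocating stream's visited set, i.e.
        # every marker is a proven predecessor of that stream's future work.
        visited = self._extend_visited(alloc_stream, terminals)
        ready = [b for b in self.pending if b.markers <= visited]
        self.pending = [b for b in self.pending if not b.markers <= visited]
        return ready


# Tiny usage example: node `c` on stream 1 transitively depends on both markers
# recorded at free time, so the block becomes reusable at the next allocation.
a = Node()
b = Node(deps=[a])
c = Node(deps=[b])
ctx = GraphReuseContext()
blk = Block(alloc_stream=1)
ctx.record_free(blk, {0: [a], 1: [b]})
assert ctx.reusable_blocks(alloc_stream=1, terminals=[c]) == [blk]
```

Because the per-stream visited sets persist across allocations, each graph node is traversed at most once per stream, which is what removes the repeated full-graph traversals behind the capture-time regression.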