Users may want CUDAGraph for certain sizes and fallback for other sizes.
As discussed in Issue #121968, we would like to use cudagraph for [batch size [1,2,3,...,16]](https://github.com/pytorch/pytorch/issues/121968#issuecomment-2259942345) and fallback for others.
Another use case is [vllm](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/cuda_piecewise_backend.py#L114-L119), where 67 batch sizes (i.e., [1,2,4,8,16,24,32,...,512]) are captured and all other sizes fallback.
This PR implements the feature with `torch._inductor.config.triton.cudagraph_capture_sizes`. When it is specified, we only capture cudagraph for these shapes. When it is None (by default), we capture cudagraph for all shapes.
Example:
```python
import torch
torch._inductor.config.triton.cudagraph_capture_sizes = [(2,3), (4,5), (6, 2), (7,3)]
def f(x):
return x + 1
f = torch.compile(f, mode="reduce-overhead", dynamic=False)
def run(batch_size, seq_len, d):
x = torch.randn((batch_size, seq_len, d), device="cuda")
# Need to mark the dimension as dynamic. Automated-dynamic
# may have some ux issues on matching `cudagraph_capture_sizes`
# with the actual dynamic shapes, since there are specialization and
# multiple dynamo graphs.
torch._dynamo.mark_dynamic(x, 0)
torch._dynamo.mark_dynamic(x, 1)
for _ in range(3):
f(x)
for i in range(2, 10):
for j in range(2, 10):
run(i, j, 8)
num_cudagraph = torch._inductor.cudagraph_trees.get_container(0).tree_manager.new_graph_id()
assert num_cudagraph.id == 4
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156551
Approved by: https://github.com/bobrenjc93
These hooks are used by internal stuck job detection to associate compilation events with the compile lease. Previously, we only had events for Dynamo and Inductor compilation. And recently, the callback handler was updated to ignore nested events. So the Inductor event was only really used by lazy backward.
Here, I remove the inductor event, and add an explicit lazy backward one. Additionally, I add other runtime compilation events: autotuning and cudagraphs. I also expose the CompileId as a string to avoid imports, this will let internal UIs track each graph's contribution to the timeout.
```python
class CallbackTrigger(enum.Enum):
# most common case, dynamo attempts to trace a new frame
DYNAMO = 1
# backward compilation can be deferred to runtime
LAZY_BACKWARD = 2
# some backends autotune at runtime
TRITON_AUTOTUNING = 3
# cudagraphs record at runtime
CUDAGRAPH_RECORDING = 4
```
Differential Revision: [D75092426](https://our.internmc.facebook.com/intern/diff/D75092426)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153596
Approved by: https://github.com/masnesral
Fix for https://github.com/pytorch/pytorch/issues/152425
inductor specializes whether or not a tensor is 16-bit aligned on the first invocation. then, on subsequent invocations, if we inferred alignment but are passed a non-aligned tensor we clone the tensor.
If we infer alignment, then run with unaligned, and mutate the input, we need to reflect back the mutation to the input. This pr adds back that mutation.
We could have also been less aggressive about inferring alignment for mutated tensors, but that has a pretty perf hit.See the following benchmark:
```
import torch
t = torch.rand(4096 * 4096, device="cuda", dtype=torch.float16)
@torch.compile(dynamic=False)
def foo(x):
return x.add_(1)
import triton
print(triton.testing.do_bench(lambda: foo(t[:-1])))
torch._dynamo.reset()
print(triton.testing.do_bench(lambda: foo(t[1:])))
```
gives
```
0.04063070610165596
0.07613472988113162
```
So almost twice as slow for non-aligned tensors. Tensors changing alignment is a relatively rare case.
In the future, we could considering a multi-kernel approach, or codegening a triton kernel that does most of the loads with aligned instructions, and a prologue/epilogue of un-alignment. But, it's yet to be seen this is a huge issue.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154442
Approved by: https://github.com/bobrenjc93, https://github.com/bdhirsh
I tried `beginAllocateToPool` instead of `_cuda_beginAllocateCurrentStreamToPool` and the error in #151199 does not happen any more.
However, this approach is unsafe for multithreading. When multiple run_eager happens concurrently, we expect memory allocation to different mem_pool. Since beginAllocateToPool does not check stream, these memory allocation may happen on the same mem_pool.
So, I use `_cuda_beginAllocateCurrentThreadToPool` to direct all memory allocation on the same thread to a given mem_pool. In particular, `_cuda_beginAllocateCurrentThreadToPool` records the launching thread id, and during runtime checks if the current thread id matches the launching thread id.
Fixes#151199
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152472
Approved by: https://github.com/eellison, https://github.com/ngimel
I tried `beginAllocateToPool` instead of `_cuda_beginAllocateCurrentStreamToPool` and the error in #151199 does not happen any more.
However, this approach is unsafe for multithreading. When multiple run_eager happens concurrently, we expect memory allocation to different mem_pool. Since beginAllocateToPool does not check stream, these memory allocation may happen on the same mem_pool.
So, I use `_cuda_beginAllocateCurrentThreadToPool` to direct all memory allocation on the same thread to a given mem_pool. In particular, `_cuda_beginAllocateCurrentThreadToPool` records the launching thread id, and during runtime checks if the current thread id matches the launching thread id.
Fixes#151199
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152472
Approved by: https://github.com/eellison
Summary: I'm investigating differences in total torch.compile overhead in our two main internal sources: dynamo_compile and pt2_compile_events. One source of discrepancy is due to cudagraphs overheads. Currently, we have a context manager that optionally attributes a dynamo_timed region to a cudagraph-related column logged to dynamo_compile, but _all_ dynamo_timed regions show up in pt2_compile_events (hence the discrepancy; pt2_compile_events is overcounting). We could filter out these specific events from pt2_compile_events when measuring overall overhead. But I'm going to argue that those timed regions that we DO NOT consider as a compiler-related overhead don't have much value in logging in the first place. So I'm suggesting we just remove those instances.
Here's the production job with the discrepancy:
* dynamo_compile: https://fburl.com/scuba/dynamo_compile/3604eypl
* pt2_compile_events: https://fburl.com/scuba/pt2_compile_events/c2dv8sty
Test Plan:
torchbench nanogpt:
* tlparse: https://fburl.com/h1n2ascc
* dynamo_compile: https://fburl.com/scuba/dynamo_compile/sandbox/u37yrynp
* pt2_compile_events: https://fburl.com/scuba/pt2_compile_events/s7avd0di
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152136
Approved by: https://github.com/BoyuanFeng
Summary: There are a few issues I'm solving:.
1. It's too hard to measure total pt2 overhead using the dynamo_compile table because users need to know the columns representing all the top-level events (dynamo_cumulative_compile_time_us, etc.). Instead, let's populate the existing duration_us field for all top-level events. The complication is that runtime events in particular (Triton autotuning, cudagraphify) can be collapsed into a single row, with gaps in between, so we can't simply use `end_time - start_time` in all cases. Instead, we'll sum durations for all outer events when updating the compile-time or runtime metrics context. Introduce a 'depth' counter in TLS to track the nesting of CompilationMetrics events.
2. The existing implementation relies on callers of dynamo_timed to specify whether the event is a runtime or compile-time event. That doesn't work because some methods can be called in both situations, e.g., `CachingAutotuner.benchmark_all_configs`. For example `TORCHINDUCTOR_BENCHMARK_FUSION=1` enables benchmarking during compile-time. Instead, we can figure out automatically whether we're measuring a compile-time or runtime event and log accordingling.
3. If `log_compilation_events` were to throw an exception, we'd fail to clear the aggregated counters for runtime logs and they could be attributed to the wrong compile ID. I didn't actually find evidence of this in practice, but I added exception handling for extra safety.
Test Plan:
Ran internal models and compared dynamo_compile to pt2_compile_events:
`TORCHINDUCTOR_BENCHMARK_FUSION=0`
* tlparse: https://fburl.com/itciwnxc
* dynamo_compile: https://fburl.com/scuba/dynamo_compile/yvkif5vb
* pt2_compile_events: https://fburl.com/scuba/pt2_compile_events/segijet7
`TORCHINDUCTOR_BENCHMARK_FUSION=1`
* tlparse: https://fburl.com/jgurcvkw
* dynamo_compile: https://fburl.com/scuba/dynamo_compile/uum91ceb
* pt2_compile_events: https://fburl.com/scuba/pt2_compile_events/x4xnisez
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151749
Approved by: https://github.com/Skylion007
Cudagraphs is careful to not allow any memory recorded to escape globally without having a reference to the tensor. This is because we may later reclaim that memory for a cudagraph recording and we need to mark the tensor as erroring on access. Very occasionally, a stray tensor will have been allocated locally but not yet cleaned up. In this case, we enter the slow path and try to gc.collect() to deallocate it. From a hard to repro internal use case, this was fixed by an additional `cuda.synchronize()`.
i also snuck in an outdated comment and a duplicate line removal.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149741
Approved by: https://github.com/BoyuanFeng, https://github.com/Skylion007
Summary: this adds some new dynamo_timed calls in cudagraph_trees, primarily with the aim to add cudagraph-related timing to scuba. Things to note:
* Uses the changes in https://github.com/pytorch/pytorch/pull/141919 to log "runtime" entries
* The logging for chromium/tlparse/scuba relies on us providing a compile_id since it's not available in the environment. A lot of the changes here are just passing around the compile_id
* I believe the spirit of the scuba logging is to capture the overheads of `torch.compile`. Therefore, I'm not adding _every_ dynamo_timed to scuba. For example, "run_eager" is the first real execution of the inductor graph -- it's not cudagraph overhead, per se. Watch out for the two instances of `dynamo_compile_runtime_column_us="runtime_cudagraphify_time_us"`. Those are the spots I believe are _extra_ overhead we'd contribute to torch.compile.
Test Plan:
`python benchmarks/dynamo/torchbench.py --performance --training --amp --backend inductor --device cuda --print-compilation-time --repeat 5 --cold-start-latency --only dcgan`:
* tlparse: https://fburl.com/21yrdn8h
* scuba: https://fburl.com/scuba/dynamo_compile/sandbox/wt90wnjz
`python benchmarks/dynamo/torchbench.py --performance --training --amp --backend inductor --device cuda --print-compilation-time --repeat 5 --cold-start-latency --only nanogpt`
* tlparse: https://fburl.com/r9mp7uiv
* scuba: https://fburl.com/scuba/dynamo_compile/sandbox/1nvx94re
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143220
Approved by: https://github.com/eellison
TODO:
- [x] Add handling for when forward is invoked multiple times without invoking backward, so that the fwd/backward states are out of sync
- [x] Update rng state initialization to take from correct device
- [x] Tests
- [x] handling of retain_graph
- [x] respect fallback random
Fix for https://github.com/pytorch/pytorch/issues/130123.
Updates the aot_eager and cudagraph compilation of `run_and_save_rng_state` to use the new mechanism added by https://github.com/pytorch/pytorch/pull/114068 for CUDAGraph safe rng states.
We have a pair of rng states for the fwd and backward respectively. In both forward and backward the rng op will get run with `graphsafe_run_with_rng_state` which takes in RNG state and it hooks onto the current RNG generator before running the operator. The rng states for fwd/backward are initialized with the same value. We ensure that for any given run of the forward, the corresponding backward run will have the same rng states for the op as was observed in the forward.
```
===== Forward graph 1 =====
/data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", fwd_rng_state_0):
sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)
# No stacktrace found for following nodes
graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = fwd_rng_state_0); fwd_rng_state_0 = None
...
===== Backward graph 1 =====
def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", tangents_1: "f32[4, 4][4, 1]cuda:0", bwd_rng_state_0):
sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)
# No stacktrace found for following nodes
graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = bwd_rng_state_0); bwd_rng_state_0 = None
```
There is some extra complication when a user either calls backward with retain_graph, or calls the backward in a different order as they called the forward. If a user has state fwd_rng_state0, bwd_rng_state0 and calls:
- fwd0: fwd_rng_state0 -> fwd_rng_state1
- fwd1: fwd_rng_state1 -> fwd_rng_state2
- bwd1
- bwd0
Then naively, when bwd1 is invoked the bwd rng states would not be equal to the same states that were observed in fwd1. I added handling of this in the aot runtime wrappers to detect pending backward invocations, and the current position of the bwd rng states, and to update when necesssary.
Other notes:
Because nodes which appear later in the forward appear earlier in the backward, we need a separate rng state for each operator. If we reused the rng across ops, the forward and backward would be run with different rng states. I.e., not applied in the same order.
Questions for reviewers:
This does change numerics, bc the rng of the op is now taken from the input rng state instead of whatever the rng would be midway through running the graph. Technically, we only need this for cuda graph. But, I'd prefer to not have a rng divergence just for cudagraph. I am making it respect `fallback_random`.
Edit: decided to apply to non cudagraphs as well, so long as fallback_random is not set
I'm initializing the rng states by cloning the current state. If you had something like 5 different rands in the model with the same shape, theyd all get the same value. This doesn't seem great. I could use some other initialization scheme like taking seed from graph position, or etc etc. Not sure. Let me know thoughts.
Edit: updated to be taken from randint()
Update: initializing rng states from torch.randint..
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146878
Approved by: https://github.com/anijain2305, https://github.com/bdhirsh
TODO:
- [x] Add handling for when forward is invoked multiple times without invoking backward, so that the fwd/backward states are out of sync
- [x] Update rng state initialization to take from correct device
- [x] Tests
- [x] handling of retain_graph
- [x] respect fallback random
Fix for https://github.com/pytorch/pytorch/issues/130123.
Updates the aot_eager and cudagraph compilation of `run_and_save_rng_state` to use the new mechanism added by https://github.com/pytorch/pytorch/pull/114068 for CUDAGraph safe rng states.
We have a pair of rng states for the fwd and backward respectively. In both forward and backward the rng op will get run with `graphsafe_run_with_rng_state` which takes in RNG state and it hooks onto the current RNG generator before running the operator. The rng states for fwd/backward are initialized with the same value. We ensure that for any given run of the forward, the corresponding backward run will have the same rng states for the op as was observed in the forward.
```
===== Forward graph 1 =====
/data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", fwd_rng_state_0):
sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)
# No stacktrace found for following nodes
graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = fwd_rng_state_0); fwd_rng_state_0 = None
...
===== Backward graph 1 =====
def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", tangents_1: "f32[4, 4][4, 1]cuda:0", bwd_rng_state_0):
sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)
# No stacktrace found for following nodes
graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = bwd_rng_state_0); bwd_rng_state_0 = None
```
There is some extra complication when a user either calls backward with retain_graph, or calls the backward in a different order as they called the forward. If a user has state fwd_rng_state0, bwd_rng_state0 and calls:
- fwd0: fwd_rng_state0 -> fwd_rng_state1
- fwd1: fwd_rng_state1 -> fwd_rng_state2
- bwd1
- bwd0
Then naively, when bwd1 is invoked the bwd rng states would not be equal to the same states that were observed in fwd1. I added handling of this in the aot runtime wrappers to detect pending backward invocations, and the current position of the bwd rng states, and to update when necesssary.
Other notes:
Because nodes which appear later in the forward appear earlier in the backward, we need a separate rng state for each operator. If we reused the rng across ops, the forward and backward would be run with different rng states. I.e., not applied in the same order.
Questions for reviewers:
This does change numerics, bc the rng of the op is now taken from the input rng state instead of whatever the rng would be midway through running the graph. Technically, we only need this for cuda graph. But, I'd prefer to not have a rng divergence just for cudagraph. I am making it respect `fallback_random`.
Edit: decided to apply to non cudagraphs as well, so long as fallback_random is not set
I'm initializing the rng states by cloning the current state. If you had something like 5 different rands in the model with the same shape, theyd all get the same value. This doesn't seem great. I could use some other initialization scheme like taking seed from graph position, or etc etc. Not sure. Let me know thoughts.
Edit: updated to be taken from randint()
Update: initializing rng states from torch.randint..
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146878
Approved by: https://github.com/anijain2305, https://github.com/bdhirsh
* Automatically applies ruff rule 401. Turns loops into equivalent list comprehensions which are faster and do not leak the scope of the loop variables.
* list comprehensions not only often have better typing, but are 50+% faster than for loops on overhead. They also preserve length information etc and are better for the interpreter to optimize.
* Manually went back and made mypy happy after the change.
* Also fixed style lints in files covered by flake8 but not by pyfmt
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140980
Approved by: https://github.com/justinchuby, https://github.com/malfet
Summary: Add time log to cudagraph, including `create deferred_cudagraphify wrapper`, `warmup`, `record`, and `checkpoint`.
Test Plan:
1. buck2 run fbcode//mode/opt //pytorch/benchmark:run -- resnet50 -d cuda -t train --inductor --pt2-triton-cudagraph
2. Found the result in [scuba table](https://fburl.com/scuba/pt2_compile_events/0oik8nu9).
{F1954034920}
Differential Revision: D65505659
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139818
Approved by: https://github.com/eellison
Type annotations for compile_fx.
- Some of the stuff here is pretty complicated (functions which return functions that take functions) so I bailed on those and used `Any` just to get the rest landed.
- There are also changes to type signatures in other files which I did just to let mypy know more about the types in compile_fx.py.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138033
Approved by: https://github.com/Skylion007
This PR mostly refactors by putting code into utils files so that they can be shared between codecache.py and compile_fx.py. Afterwards, it then changes compile_fx so that:
- When saving to FXGraphCache, we save onto the CompiledFXGraph all the necessary metadata for running post compile steps (realigning inputs, cudagraphification).
- When loading from FXGraphCache, we use the saved information directly, instead of calculating them from scratch.
What this does is make it so that `FXGraphCache.load()` is a perfect cache on compile_fx_inner, in that it **returns exactly what compile_fx_inner returns**. This also makes it possible for AOTAutogradCache, given a key to the fx graph cache and example inputs, to get back the full return value of compile_fx_inner.
## What's a post compile step?
We define a **post-compile** to be the set of actions that need to run after FXGraphCache either loads from the cache or misses and runs compilation. These steps include:
- Setting the tracing context's output strides
- Running cudagraphs if enabled
- Maybe realign inputs if cudagraphs didn't run
To run these steps, we save all the necessary metadata in CompiledFxGraph, and use them on a cache hit to reconstruct the object.
## Splitting cudagraphs work into pre/post compile
Cudagraphs does a lot of work on the input graph module to determine if cudagraphs can be enabled. This is the code that involves cudagraph_tests and stack traces. This will work in a world where we have access to the input graph module, but with AOTAutograd warm start, we won't have access to that information anymore. Therefore we can split cudagraphs work into two parts: on a cache miss (and therefore a full compile), we do the cudagraphs testing work, and save cudagraph_fail_reasons into the cache. Then on a cache hit, we know whether or not we can run cudagraphs, and if we can't, we can emit the correct error messages.
Implementation notes:
- We save `fx_kwargs` directly onto the CompiledFXGraph. `fx_kwargs` is already, by definition, part of the cache key, so this is safe to do when it comes to cache correctness.
- ^ Why do we do above even though FXGraphCache.load takes fx_kwargs as an argument? Because AOTAutogradCache **doesn't** have access to fx_kwargs: they're annoyingly encoded in the functools.partial() of the fw_compiler, so *only* inductor knows about these options. They're fully captured by the AOTAutogradCache key (since every key to fx_kwargs is either a global config, or a field that's deterministic based on an input graph module), but their values are still needed to run cudagraphs/postprocessing. Therefore, it's easier/safer to store it on the cached result.
- Willing to hear other approaches here if we think saving these extra fields is not reasonable, though I can't think of another way to do this that's less complicated to explain.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130572
Approved by: https://github.com/eellison