pytorch/torch/csrc/inductor/aoti_eager
Joel Schlosser 85467ed063 Fix for AOTI + CUDAGraphs when calling from Python (#148601)
**Background**: I've been comparing performance of torch.compile vs. torch.export + AOTI (specifically, loaded from Python) on the Flux model and found a ~1.4% performance decrease with the latter. The trace shows that CUDAGraphs are not utilized for torch.export + AOTI, leading to higher overhead.

When trying to manually CUDAGraph the loaded, previously exported + AOTIed model (thanks to @eellison for the logic here), I get:
```
Error: operation not permitted when stream is capturing
```

@desertfire confirms that this is due to multi-threading logic on the AOTI runtime side (in `AOTIModelContainer` / `AOTIModel`) conflicting with the use of CUDAGraphs.

**Fix**: This PR takes the approach of providing an alternate, single-threaded method for running loaded models with the AOTI runtime. Details:
* Python side introduces a new flag to enable this behavior (needs a better name), off by default: `torch._inductor.package.load_package(..., run_single_threaded=False)`
    * This flag is passed down to the C++ side's `AOTIModelPackageLoader`, which passes it to the `CreateAOTIModelRunnerFunc` during `AOTIModelContainerRunner` construction.
* C++ side introduces single-threaded alternatives to model running and model container running:
    * `AOTIModelContainer.run_single_threaded()` / `AOTIModel.run_single_threaded()`. The interfaces match those of `run()`, but the synchronization logic has been removed.
    * Introduces `AOTInductorModelContainerRunSingleThreaded` to AOTI's `interface.h`; this is invoked by the `AOTIModelContainerRunner` utility class when `run_single_threaded=true`.

I've verified on both a small repro and my real-world use case that I can manually CUDAGraph a loaded model that was previously exported + AOTIed.
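
For illustration, manually CUDAGraphing a loaded model might look roughly like the sketch below. It uses the usual capture-and-replay pattern with `torch.cuda.CUDAGraph` and is not necessarily the exact logic used for the verification above. The helper name `cudagraph_wrap`, the `warmup_iters` parameter, and the `model.pt2` path are hypothetical; it also assumes the model takes and returns CUDA tensors with fixed shapes.

```python
def cudagraph_wrap(model, example_inputs, warmup_iters=3):
    """Capture `model` into a CUDA graph and return a replaying callable.

    Hypothetical helper: assumes all inputs are CUDA tensors whose shapes
    do not change between calls.
    """
    import torch  # imported lazily so the helper can be defined without CUDA

    # Static buffers: the captured graph always reads from these tensors.
    static_inputs = [x.clone() for x in example_inputs]

    # Warm up on a side stream so lazy initialization happens outside capture.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        for _ in range(warmup_iters):
            model(*static_inputs)
    torch.cuda.current_stream().wait_stream(side)

    # Capture a single forward pass.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_outputs = model(*static_inputs)

    def run(*new_inputs):
        # Copy fresh data into the static buffers, then replay the graph.
        for dst, src in zip(static_inputs, new_inputs):
            dst.copy_(src)
        graph.replay()
        # Note: these output tensors are overwritten by the next replay.
        return static_outputs

    return run


# Usage (needs a CUDA device and a package produced by torch.export + AOTI):
#   import torch
#   model = torch._inductor.package.load_package(
#       "model.pt2", run_single_threaded=True)
#   fast_model = cudagraph_wrap(model, (example_input,))
#   out = fast_model(new_input)
```

Without `run_single_threaded=True`, the capture step is where the "operation not permitted when stream is capturing" error described above would be expected to surface.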

**Future work:**
* Flip default value to `run_single_threaded=True` as Python-side inference doesn't take advantage of the AOTI runtime thread pool
    * There are some backward-compatibility (BC) concerns here: models need to be re-serialized so that the .so contains the new `AOTInductorModelContainerRunSingleThreaded` interface function. We can flip the default value and warn (instead of crashing) if the `AOTInductorModelContainerRunSingleThreaded` symbol does not exist.
* Compose with CUDAGraph trees as opposed to manual CUDAGraph wrapping
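
The warn-instead-of-crash fallback mentioned under future work could be sketched roughly as follows. This is only an illustration using `ctypes` from Python, not how the AOTI runtime does or would implement it (the real check would more likely live on the C++ side); `has_symbol` and `resolve_run_single_threaded` are hypothetical names.

```python
import ctypes
import warnings

SINGLE_THREADED_SYMBOL = "AOTInductorModelContainerRunSingleThreaded"


def has_symbol(lib_path, symbol):
    """Return True if the shared library at `lib_path` exports `symbol`.

    Passing None inspects the current process (POSIX only).
    """
    lib = ctypes.CDLL(lib_path)
    try:
        getattr(lib, symbol)  # ctypes raises AttributeError if it is missing
    except AttributeError:
        return False
    return True


def resolve_run_single_threaded(lib_path, requested=True):
    """Fall back to the multi-threaded path, with a warning, for old .so files."""
    if requested and not has_symbol(lib_path, SINGLE_THREADED_SYMBOL):
        warnings.warn(
            f"{SINGLE_THREADED_SYMBOL} not found; the model was likely "
            "serialized before this interface existed. Falling back to the "
            "multi-threaded run path.",
            stacklevel=2,
        )
        return False
    return requested
```

With a check like this, a model serialized before the new interface function existed would still load and run, just without the single-threaded path.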

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148601
Approved by: https://github.com/desertfire
2025-03-08 02:44:14 +00:00
kernel_holder.cpp Fix for AOTI + CUDAGraphs when calling from Python (#148601) 2025-03-08 02:44:14 +00:00
kernel_holder.h Enable more readability-redundant checks (#143963) 2024-12-30 14:49:33 +00:00
kernel_meta_info.cpp [2/N] Avoid copy in std::get (#141826) 2024-12-02 00:16:48 +00:00
kernel_meta_info.h