Commit Graph

57 Commits

Simon Fan
f889dea97d [internal] Expose additional metadata to compilation callbacks (#153596)
These hooks are used by internal stuck-job detection to associate compilation events with the compile lease. Previously we only had events for Dynamo and Inductor compilation, and the callback handler was recently updated to ignore nested events, so the Inductor event was effectively only used by lazy backward.

Here, I remove the Inductor event and add an explicit lazy-backward one. Additionally, I add other runtime compilation events: autotuning and cudagraphs. I also expose the CompileId as a string to avoid imports; this lets internal UIs track each graph's contribution to the timeout.

```python
import enum


class CallbackTrigger(enum.Enum):
    # most common case, dynamo attempts to trace a new frame
    DYNAMO = 1
    # backward compilation can be deferred to runtime
    LAZY_BACKWARD = 2
    # some backends autotune at runtime
    TRITON_AUTOTUNING = 3
    # cudagraphs record at runtime
    CUDAGRAPH_RECORDING = 4
```
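
A minimal sketch of how a stuck-job callback might consume this metadata (the callback names and registration mechanism below are illustrative assumptions, not the actual torch._dynamo API):

```python
import time

# accumulated compile seconds per graph, keyed by the CompileId string
_per_graph_seconds: dict[str, float] = {}
_starts: dict[str, float] = {}


def on_compile_start(trigger, compile_id: str) -> None:
    # `trigger` is one of the CallbackTrigger values from the enum above
    _starts[compile_id] = time.monotonic()


def on_compile_end(trigger, compile_id: str) -> None:
    start = _starts.pop(compile_id, None)
    if start is not None:
        _per_graph_seconds[compile_id] = (
            _per_graph_seconds.get(compile_id, 0.0) + time.monotonic() - start
        )
```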

Differential Revision: [D75092426](https://our.internmc.facebook.com/intern/diff/D75092426)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153596
Approved by: https://github.com/masnesral
2025-05-30 08:07:04 +00:00
James Wu
dda2c7c8fc Pass inductor config for static cuda launcher to workers (#153382)
Async compile workers generally don't respect inductor configs that change in the middle of execution, because the workers warm up early. StaticCudaLauncher is especially susceptible to this because it affects Triton compilation without being part of the inductor meta. So we pass it in via extra configs on each worker run.
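
A minimal sketch of the idea, with hypothetical helper names and config key (the real values flow through torch._inductor.config): the parent captures the relevant config at submit time and the worker re-applies it before compiling.

```python
from concurrent.futures import ProcessPoolExecutor

# stale config snapshot captured when the worker warmed up
_WORKER_CONFIG: dict = {}


def _apply_extra_configs(extra_configs: dict) -> None:
    # in the real code this would write into torch._inductor.config
    _WORKER_CONFIG.update(extra_configs)


def _worker_compile(kernel_source: str, extra_configs: dict) -> str:
    _apply_extra_configs(extra_configs)  # refresh the snapshot for this task
    static = _WORKER_CONFIG.get("use_static_cuda_launcher", False)  # illustrative key
    return f"compiled {len(kernel_source)} bytes (static launcher: {static})"


if __name__ == "__main__":
    extra = {"use_static_cuda_launcher": True}  # read from the live config at submit time
    with ProcessPoolExecutor(max_workers=1) as pool:
        print(pool.submit(_worker_compile, "kernel src", extra).result())
```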

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153382
Approved by: https://github.com/masnesral, https://github.com/jansel
2025-05-14 20:01:32 +00:00
Jovian Anthony Jaison
5d36485b4a Log aot and idx waitcounters. (#152444)
Summary:
Added for create_aot_dispatcher_function and compile_fx_inner.
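
A rough sketch of the wait-counter pattern (a stand-in context manager is used here; the real code goes through PyTorch's internal wait-counter machinery, and the counter name is illustrative):

```python
import time
from contextlib import contextmanager

_WAIT_COUNTERS: dict[str, float] = {}


@contextmanager
def wait_counter(name: str):
    start = time.monotonic()
    try:
        yield
    finally:
        _WAIT_COUNTERS[name] = _WAIT_COUNTERS.get(name, 0.0) + (
            time.monotonic() - start
        )


def compile_fx_inner_logged(*args, **kwargs):
    # wrap the expensive entry point named above; the real call goes here
    with wait_counter("pytorch.wait_counter.compile_fx_inner"):
        pass
```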

Note:
Log wait counters flag is already set for:
1. async_compile.precompile
2. remote_fx_graph_cache_get
3. remote_fx_graph_cache_put

Test Plan: contbuild

Differential Revision: D73866124

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152444
Approved by: https://github.com/ppanchalia, https://github.com/masnesral
2025-05-06 16:21:58 +00:00
PyTorch MergeBot
172a7c942e Revert "Log aot and idx waitcounters. (#152444)"
This reverts commit ea9ea02959.

Reverted https://github.com/pytorch/pytorch/pull/152444 on behalf of https://github.com/jovianjaison due to needs a fix ([comment](https://github.com/pytorch/pytorch/pull/152444#issuecomment-2851905261))
2025-05-05 18:11:37 +00:00
Jovian Anthony Jaison
ea9ea02959 Log aot and idx waitcounters. (#152444)
Summary:
Added for create_aot_dispatcher_function and compile_fx_inner.

Note:
Log wait counters flag is already set for:
1. async_compile.precompile
2. remote_fx_graph_cache_get
3. remote_fx_graph_cache_put

Test Plan: contbuild

Differential Revision: D73866124

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152444
Approved by: https://github.com/ppanchalia, https://github.com/masnesral
2025-05-05 17:35:29 +00:00
James Wu
93d8f6ee32 [reland] Detailed triton kernel logging (#152694)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152694
Approved by: https://github.com/Skylion007
2025-05-05 02:46:57 +00:00
PyTorch MergeBot
fecaa60c3c Revert "Add detailed triton kernel logging to tlparse (#152197)"
This reverts commit 8303860de7.

Reverted https://github.com/pytorch/pytorch/pull/152197 on behalf of https://github.com/wdvr due to failing python test/dynamo/test_structured_trace.py StructuredTraceTest.test_cudagraphs on trunk ([comment](https://github.com/pytorch/pytorch/pull/152197#issuecomment-2840400839))
2025-04-29 22:47:48 +00:00
James Wu
8303860de7 Add detailed triton kernel logging to tlparse (#152197)
This PR adds detailed logging of each triton kernel we compile, and its autotune result, to every kernel we compile with triton. We add these results to a global variable that we then clear after each triton kernel compile.

We can't keep these objects around after compile time, so unfortunately we can't record the autotune cache save or coordinate-descent tuning, but we can log at least:
- The duration of compilation
- Whether or not the autotune cache hit
- The best autotuning config, if there's only one.
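
A sketch of the pattern described above (names are illustrative, not the actual torch._inductor internals): metadata is accumulated in a global during a single kernel compile, then drained into the structured log and cleared.

```python
import json
import time

_CURRENT_KERNEL_INFO: dict = {}


def compile_one_kernel(kernel_name: str, src: str) -> None:
    start = time.monotonic()
    autotune_cache_hit = False  # would come from the autotune-cache lookup
    # ... triton compilation / autotuning happens here ...
    _CURRENT_KERNEL_INFO.update(
        {
            "kernel_name": kernel_name,
            "compile_duration_s": time.monotonic() - start,
            "autotune_cache_hit": autotune_cache_hit,
            "best_config": None,  # set only when a single config wins
        }
    )


def emit_and_clear(log) -> None:
    if _CURRENT_KERNEL_INFO:
        log.info("triton_kernel_info %s", json.dumps(_CURRENT_KERNEL_INFO))
        _CURRENT_KERNEL_INFO.clear()  # the global is reset after every compile
```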

Example triton kernel info: https://gist.github.com/jamesjwu/493bdd0f36b0b7e3ca327f87bd6c2c75

See internal diff for an example log for internal model.

Differential Revision: [D73674443](https://our.internmc.facebook.com/intern/diff/D73674443)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152197
Approved by: https://github.com/oulgen, https://github.com/eellison
2025-04-29 18:16:56 +00:00
James Wu
0dae27d75b Turn on static cuda launcher in OSS (#151691)
After a few small bugfixes on tests (so that we throw/catch exceptions similar to Triton's), I think we're ready to flip the switch and turn StaticCudaLauncher on by default in OSS.

The initial round of benchmarks looks good, with average compilation time going down by a few percent:
<img width="828" alt="image" src="https://github.com/user-attachments/assets/cad03e09-b4d6-49a7-a9e5-6068d1c0bd5c" />

With no changes to runtime perf:
<img width="823" alt="image" src="https://github.com/user-attachments/assets/3fcd435e-1057-43f4-878b-8d66a3812a10" />

There are a few noisy models I want to double-check, though, so I'll run some more tests before accepting review.

Full benchmark results, showing a ~5% compile time improvement across the board:
https://hud.pytorch.org/benchmark/huggingface/inductor_with_cudagraphs?dashboard=torchinductor&startTime=Wed%2C%2016%20Apr%202025%2002%3A31%3A12%20GMT&stopTime=Wed%2C%2023%20Apr%202025%2002%3A31%3A12%20GMT&granularity=hour&mode=training&dtype=amp&deviceName=cuda%20(a100)&lBranch=gh/jamesjwu/139/orig&lCommit=cc45c8667fa23dec16ca50002d9504a34688ca5c&rBranch=main&rCommit=2a9afdae81d0dde98e96d7e3c9ca840e241e5405
<img width="1482" alt="image" src="https://github.com/user-attachments/assets/6e6a7f39-7f44-459f-9845-9a37f084ea82" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151691
Approved by: https://github.com/oulgen, https://github.com/jansel, https://github.com/EikanWang
2025-04-25 17:48:53 +00:00
James Wu
cbc0964636 Store statically launchable CachingAutotuners inside CompiledFXGraph.triton_bundle (#149054)
This PR adds CachingAutotuners that are statically launchable to FXGraphCache's cache entry.

Regular CachingAutotuners, with triton kernels attached to them, are not very good to cache: they are very large, and take huge amounts of space since they track all of the various binary files, along with various metadata. We could probably figure out what information we could delete from the kernel and have it still work, but with StaticCudaLauncher, we no longer have to. Instead, we can cache every compiled triton kernel that is statically launchable.

Because StaticTritonCompileResult is serializable, and designed to have a very small memory footprint, we can save it into FXGraphCache without increasing the cache size significantly. We store it as a part of CompiledFxGraph.triton_bundle.

Then, on load, we repopulate the CachingAutotuner into our CompiledTritonKernel cache.
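
The shape of the caching scheme, as a sketch (the dataclass fields and helper names below are assumptions; the real logic lives in StaticTritonCompileResult and the TritonBundler): a small, picklable record is saved with the cache entry and rehydrated into the in-memory kernel cache on load.

```python
from dataclasses import asdict, dataclass


@dataclass
class StaticCompileRecord:
    kernel_name: str
    cubin_path: str        # compiled binary on disk
    launch_metadata: dict  # just enough to launch: grid, shared memory, arg types


# in-memory cache repopulated when an FXGraphCache entry is loaded
_COMPILED_TRITON_KERNELS: dict[str, StaticCompileRecord] = {}


def save_to_bundle(records: list[StaticCompileRecord]) -> list[dict]:
    # small, picklable payload stored alongside CompiledFxGraph.triton_bundle
    return [asdict(r) for r in records]


def load_from_bundle(payload: list[dict]) -> None:
    for entry in payload:
        rec = StaticCompileRecord(**entry)
        _COMPILED_TRITON_KERNELS[rec.kernel_name] = rec  # ready to launch, no recompile
```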

The upsides of this are many:
- We no longer need to call into a separate process on cache hit
- We can *guarantee* that the triton kernel we got from our cache entry is the one we use to launch again, so no worries about triton's own caching logic
- Once we achieve feature parity and all torch.compiled triton kernels are statically launchable, we can clean up a bunch of TritonBundler code and simplify the cache hit logic.

Fixes #149449

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149054
Approved by: https://github.com/oulgen
2025-03-30 17:51:11 +00:00
PyTorch MergeBot
7c4e49750e Revert "Store statically launchable CachingAutotuners inside CompiledFXGraph.triton_bundle (#149054)"
This reverts commit c16af5d798.

Reverted https://github.com/pytorch/pytorch/pull/149054 on behalf of https://github.com/jamesjwu due to Sorry I forgot to fix one last test ([comment](https://github.com/pytorch/pytorch/pull/149054#issuecomment-2761381443))
2025-03-28 13:35:07 +00:00
James Wu
c16af5d798 Store statically launchable CachingAutotuners inside CompiledFXGraph.triton_bundle (#149054)
This PR adds CachingAutotuners that are statically launchable to FXGraphCache's cache entry.

Regular CachingAutotuners, with triton kernels attached to them, are not very good to cache: they are very large, and take huge amounts of space since they track all of the various binary files, along with various metadata. We could probably figure out what information we could delete from the kernel and have it still work, but with StaticCudaLauncher, we no longer have to. Instead, we can cache every compiled triton kernel that is statically launchable.

Because StaticTritonCompileResult is serializable, and designed to have a very small memory footprint, we can save it into FXGraphCache without increasing the cache size significantly. We store it as a part of CompiledFxGraph.triton_bundle.

Then, on load, we repopulate the CachingAutotuner into our CompiledTritonKernel cache.

The upsides of this are many:
- We no longer need to call into a separate process on cache hit
- We can *guarantee* that the triton kernel we got from our cache entry is the one we use to launch again, so no worries about triton's own caching logic
- Once we achieve feature parity and all torch.compiled triton kernels are statically launchable, we can clean up a bunch of TritonBundler code and simplify the cache hit logic.

Fixes #149449

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149054
Approved by: https://github.com/oulgen
2025-03-28 13:28:05 +00:00
PyTorch MergeBot
80aa88f907 Revert "Store statically launchable CachingAutotuners inside CompiledFXGraph.triton_bundle (#149054)"
This reverts commit ac91f8765b.

Reverted https://github.com/pytorch/pytorch/pull/149054 on behalf of https://github.com/yangw-dev due to This is breaking ROCM tests on trunk. hud.pytorch.org/ ([comment](https://github.com/pytorch/pytorch/pull/149054#issuecomment-2759604301))
2025-03-27 22:15:40 +00:00
James Wu
ac91f8765b Store statically launchable CachingAutotuners inside CompiledFXGraph.triton_bundle (#149054)
This PR adds CachingAutotuners that are statically launchable to FXGraphCache's cache entry.

Regular CachingAutotuners, with triton kernels attached to them, are not very good to cache: they are very large, and take huge amounts of space since they track all of the various binary files, along with various metadata. We could probably figure out what information we could delete from the kernel and have it still work, but with StaticCudaLauncher, we no longer have to. Instead, we can cache every compiled triton kernel that is statically launchable.

Because StaticTritonCompileResult is serializable, and designed to have a very small memory footprint, we can save it into FXGraphCache without increasing the cache size significantly. We store it as a part of CompiledFxGraph.triton_bundle.

Then, on load, we repopulate the CachingAutotuner into our CompiledTritonKernel cache.

The upsides of this are many:
- We no longer need to call into a separate process on cache hit
- We can *guarantee* that the triton kernel we got from our cache entry is the one we use to launch again, so no worries about triton's own caching logic
- Once we achieve feature parity and all torch.compiled triton kernels are statically launchable, we can clean up a bunch of TritonBundler code and simplify the cache hit logic.

Fixes #149449

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149054
Approved by: https://github.com/oulgen
ghstack dependencies: #149657
2025-03-27 17:14:44 +00:00
Sam Larsen
c83c711da8 Remove some memory overhead in parallel compile workers (#149168)
Summary: The parallel compile workers are holding on to more memory than they need to because they're loading the compiled modules into memory. Update the post-fork initializer to record when in a subprocess and skip some of the unnecessary overhead.
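
A minimal sketch of the pattern (flag and helper names are illustrative): the pool's post-fork initializer marks the process as a compile worker, and main-process-only work such as loading the compiled module is skipped there.

```python
from concurrent.futures import ProcessPoolExecutor

_IN_COMPILE_WORKER = False


def _post_fork_init() -> None:
    global _IN_COMPILE_WORKER
    _IN_COMPILE_WORKER = True


def load_module(path: str):
    return ("loaded", path)  # stand-in for importing the generated module


def compile_kernel(src: str):
    binary_path = f"/tmp/{abs(hash(src))}.bin"  # stand-in for the compile output
    if _IN_COMPILE_WORKER:
        return binary_path           # worker: skip loading the module into memory
    return load_module(binary_path)  # main process: load it for actual use


if __name__ == "__main__":
    with ProcessPoolExecutor(max_workers=2, initializer=_post_fork_init) as pool:
        print(pool.submit(compile_kernel, "kernel src").result())
```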

Test Plan: Ran a test script to compile 15k Triton kernels and used tracemalloc in the subprocs to investigate the overhead. On my devgpu:
* After importing torch in a subproc: 371M
* Without this PR, after compiling 15k kernels: 825M
* With this PR, after compiling 15k kernels: 531M

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149168
Approved by: https://github.com/jansel
2025-03-15 14:20:40 +00:00
Sam Larsen
7cdbb913e7 [logging] Set compile_id in the CachingAutotuner during compilation so we have it for dynamo_timed logging (#148693)
Summary: This is a simpler alternative to https://github.com/pytorch/pytorch/pull/146455, where we can stick the compileId (and forward/backward bool) in the CachingAutotuner so that we have it for logging `benchmark_all_configs`. Recall that the first attempt put the compileId in the inductor_meta and that interfered with caching.
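
A sketch of the idea (attribute and method names are assumptions): the compile id is attached to the autotuner object itself rather than to inductor_meta, so it never enters the cache key but is available when `benchmark_all_configs` logs.

```python
class CachingAutotunerSketch:
    def __init__(self, kernel_name: str):
        self.kernel_name = kernel_name
        self.compile_id = None    # CompileId as a string, set after compilation
        self.is_backward = False

    def set_compile_context(self, compile_id: str, is_backward: bool) -> None:
        # called during compilation, after the cache lookup has already happened
        self.compile_id = compile_id
        self.is_backward = is_backward

    def benchmark_all_configs(self) -> None:
        # timing can now be logged against the originating compile id
        print(f"benchmarking {self.kernel_name} for compile_id={self.compile_id}")
```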

Test Plan:
`python benchmarks/dynamo/torchbench.py --performance --training --amp --backend inductor --device cuda --print-compilation-time --repeat 5 --cold-start-latency --only nanogpt`
* tlparse: https://fburl.com/e71yn6uc
* dynamo_compile: https://fburl.com/scuba/dynamo_compile/sandbox/4ageghhv
* pt2_compile_events: https://fburl.com/scuba/pt2_compile_events/4fgv1itq

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148693
Approved by: https://github.com/eellison
2025-03-13 03:50:58 +00:00
PyTorch MergeBot
b54cf1a281 Revert "[logging] Set compile_id in the CachingAutotuner during compilation so we have it for dynamo_timed logging (#148693)"
This reverts commit 73c8068cf8.

Reverted https://github.com/pytorch/pytorch/pull/148693 on behalf of https://github.com/ZainRizvi due to This is breaking lint on trunk. Please rebase these changes before merging them back in. [GH job link](https://github.com/pytorch/pytorch/actions/runs/13796723235/job/38590020554) [HUD commit link](73c8068cf8) ([comment](https://github.com/pytorch/pytorch/pull/148693#issuecomment-2715671875))
2025-03-11 20:50:23 +00:00
Sam Larsen
73c8068cf8 [logging] Set compile_id in the CachingAutotuner during compilation so we have it for dynamo_timed logging (#148693)
Summary: This is a simpler alternative to https://github.com/pytorch/pytorch/pull/146455, where we can stick the compileId (and forward/backward bool) in the CachingAutotuner so that we have it for logging `benchmark_all_configs`. Recall that the first attempt put the compileId in the inductor_meta and that interfered with caching.

Test Plan:
`python benchmarks/dynamo/torchbench.py --performance --training --amp --backend inductor --device cuda --print-compilation-time --repeat 5 --cold-start-latency --only nanogpt`
* tlparse: https://fburl.com/e71yn6uc
* dynamo_compile: https://fburl.com/scuba/dynamo_compile/sandbox/4ageghhv
* pt2_compile_events: https://fburl.com/scuba/pt2_compile_events/4fgv1itq

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148693
Approved by: https://github.com/eellison
2025-03-11 19:38:40 +00:00
James Wu
8728d4b815 Clear triton kernels after parent make_launcher (#148604)
Before, we were clearing the cache only after inductor compile. But inductor may not **always** compile, e.g. on an AOTAutogradCache hit.

So instead, we should clear it when the future is consumed. This is a more robust fix for the issue in D69476856.
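
A sketch of the timing pattern only (class and field names are illustrative, and the real cleanup semantics differ): the cleanup hook runs when the future's result is consumed, which happens on cache-hit paths too, not just at the end of an inductor compile.

```python
class KernelCompileFuture:
    """Wraps a pending compile; cleanup runs when the result is consumed."""

    def __init__(self, pending, cleanup):
        self._pending = pending  # e.g. a concurrent.futures.Future
        self._cleanup = cleanup  # e.g. clears per-process kernel state

    def result(self):
        try:
            return self._pending.result()
        finally:
            self._cleanup()  # runs even when the compile itself was a cache hit
```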

Differential Revision: [D70646281](https://our.internmc.facebook.com/intern/diff/D70646281/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148604
Approved by: https://github.com/masnesral
2025-03-06 03:28:38 +00:00
Sam Larsen
40c2505f16 [logging] Log individual Triton kernel compilation times to dynamo_compile (#147022)
Summary: Gather the compilation time of individual triton kernels and log them to dynamo_compile:
* Time compilation in `_worker_compile_triton` and pass back to the main process and logged from `get_result()`.
* Added a way to track the "top N" (or N most-expensive compiles) in the metrics_context. I did this because I doubt we really care to capture potentially thousands of kernel compile times. That would be problematic for scuba logging anyway, so let's limit the number we track from the beginning. Arbitrarily chose 25 for now.
* Format the list of compile times as a json string before logging.
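
A sketch of a "top N most expensive kernel compiles" tracker along these lines (heap-based; the real MetricsContext bookkeeping differs in detail):

```python
import heapq
import json

TOP_N = 25  # matches the arbitrary limit chosen above


class TopNCompileTimes:
    def __init__(self, n: int = TOP_N):
        self.n = n
        self._heap: list[tuple[float, str]] = []  # min-heap of (seconds, kernel)

    def add(self, kernel_name: str, seconds: float) -> None:
        item = (seconds, kernel_name)
        if len(self._heap) < self.n:
            heapq.heappush(self._heap, item)
        elif item > self._heap[0]:
            heapq.heapreplace(self._heap, item)  # evict the cheapest entry

    def to_json(self) -> str:
        # formatted as a json string before logging, per the bullet above
        ordered = sorted(self._heap, reverse=True)
        return json.dumps([{"kernel": k, "seconds": s} for s, k in ordered])
```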

Test Plan:
`python benchmarks/dynamo/torchbench.py --performance --training --amp --backend inductor --device cuda --print-compilation-time --repeat 5 --cold-start-latency --only nanogpt`
Scuba: https://fburl.com/scuba/dynamo_compile/sandbox/nc4dzm3r

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147022
Approved by: https://github.com/jamesjwu
2025-03-03 19:32:17 +00:00
Xuehai Pan
1cb4e2df65 [BE][PYFMT] migrate PYFMT for torch._inductor to ruff format (#144550)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144550
Approved by: https://github.com/jansel
2025-02-28 13:33:19 +00:00
James Wu
23524699d5 Only call triton in worker process, kick off worker processes earlier, during inductor codegen (#146417)
### Big idea
This PR extends https://github.com/pytorch/pytorch/pull/144288 by combining calling triton in worker processes with the future cache: we kick off triton compilation in the worker processes earlier, during inductor codegen. Basically instead of calling async_compile.triton for the first time only after the entire code has been generated, we start compiling as soon as we know we'll need to compile the kernel. Then, when loading the generated inductor code, we can simply read from our in memory future cache, considerably increasing the parallelism.
### Implementation Overview
In total, the diff does the following:
- Converts TritonFuture to LambdaFuture, only calling triton.compile on worker processes
- Now that triton.compile() isn't called on the main process, we call TritonBundler on all compiled kernels when we get them back from workers
- Extend @eellison's future cache to a class, mostly as a refactor
- Finally, call async_compile.triton ahead of time in Scheduler.codegen if workers are warmed up. This causes the subsequent
async_compile.triton call that occurs after codegen to cache hit on cold start.
In the diffs after this, I will add more to CompiledTritonKernels so that TritonBundler automatically populates the in-memory cache with the existing triton kernels on warm start, avoiding calling triton altogether.
Because LambdaFutures are much faster to kick off than TritonFutures, due to not needing to load from TritonCodeCache at all, the time spent kicking off these worker jobs is pretty minimal for inductor codegen.
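
A rough sketch of the future-cache pattern (class and method names are assumptions): compilation is kicked off during codegen, keyed by kernel source, so the later lookup when the generated module loads is a cache hit.

```python
import hashlib
from concurrent.futures import Future, ProcessPoolExecutor


def compile_in_worker(source: str) -> str:
    # stand-in for triton.compile running in the worker process
    return f"compiled:{len(source)}"


class CompiledTritonKernelsSketch:
    def __init__(self, pool: ProcessPoolExecutor):
        self._pool = pool
        self._futures: dict[str, Future] = {}

    @staticmethod
    def _key(source: str) -> str:
        return hashlib.sha256(source.encode()).hexdigest()

    def submit(self, source: str) -> Future:
        key = self._key(source)
        if key not in self._futures:  # identical kernels share one future
            self._futures[key] = self._pool.submit(compile_in_worker, source)
        return self._futures[key]
```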

Differential Revision: [D69123174](https://our.internmc.facebook.com/intern/diff/D69123174/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146417
Approved by: https://github.com/jansel
2025-02-11 03:46:16 +00:00
eellison
8e258e2ecd Parallelize epilogue/prologue benchmarking (#143408)
When we attempt prologue or epilogue fusion with a TritonTemplate, we benchmark it at compile time in order to determine profitability. This avoids slowdowns/register spilling, and allows us to pick fusion when a base triton template is slower than cublas but faster when considering an epilogue. However, that fused benchmarking does not do the same async compilation as we do for the base TritonTemplate. The Base TritonTemplate is async compiled during lowering, then later waited on and benchmarked.

This PR extends a similar process to benchmarking fused TritonTemplates in the scheduler. We keep a list of pending fusions which have async compilations. And we resolve any pending fusions a node is in prior to attempting to fuse it with any other node.

Initially, I saw some slowdowns with this because we kick off async compilations of identical fusions in parallel. To address this, I added source-code caching at the `async_compile` level (we also already cache benchmark runs, but that caching would not happen in parallel).
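
A sketch of the pending-fusion bookkeeping (names are illustrative): fused-template benchmarks are compiled asynchronously and tracked as pending, and a node's pending benchmarks are resolved before it is considered for fusion with any other node.

```python
class PendingFusionBenchmarks:
    def __init__(self):
        # maps a scheduler node name to futures for its in-flight fused benchmarks
        self._pending: dict[str, list] = {}

    def add(self, node_name: str, future) -> None:
        self._pending.setdefault(node_name, []).append(future)

    def resolve(self, node_name: str) -> list:
        # wait on any async compilations involving this node before deciding
        # whether to fuse it with anything else
        futures = self._pending.pop(node_name, [])
        return [f.result() for f in futures]
```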

Compilation speedups:

<img width="717" alt="image" src="https://github.com/user-attachments/assets/8e8f7d6c-7824-4210-83f9-a2a0f6db5ac9" />

This should also let us be a bit more aggressive with either configs or with benchmarking other fusions whose profitability is hard to determine.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143408
Approved by: https://github.com/jansel, https://github.com/shunting314
2025-01-28 18:18:24 +00:00
Aaron Orenstein
893ca1dfe1 PEP585 update - torch/_inductor/[_-i]* (#145137)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145137
Approved by: https://github.com/bobrenjc93
2025-01-19 01:22:47 +00:00
Sam Larsen
b801210035 Restore support for other types of async_compile pools (spawn, fork) (#144491)
Summary: https://github.com/pytorch/pytorch/pull/142001 removed support for process pools other than "subprocess", but some OSS users still find the other types useful; put them back.

Test Plan: New unit test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144491
Approved by: https://github.com/jansel, https://github.com/haifeng-jin
2025-01-15 06:04:49 +00:00
Colin L. Rice
84443bd61a feature_use: Remove JK from naming for feature use. (#143529)
See the discussion in https://github.com/pytorch/pytorch/pull/142819, but TL;DR: since we're logging use rather than direct JK reads, it's less confusing to use the logging naming.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143529
Approved by: https://github.com/ezyang
2025-01-09 17:58:22 +00:00
eellison
e890d67543 Use process pool for precompilation of triton templates (#142450)
Perf results: https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Tue%2C%2003%20Dec%202024%2022%3A57%3A51%20GMT&stopTime=Tue%2C%2010%20Dec%202024%2022%3A57%3A51%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(a100)&lBranch=gh/eellison/740/head&lCommit=b925256c29ec43e1933e4ede94b16d1f404b595f&rBranch=gh/eellison/740/base&rCommit=a161d6362f7d9db773322d2ce2a3a70aabbecf4b

Training:
<img width="793" alt="image" src="https://github.com/user-attachments/assets/75f5bc0d-8005-4213-ae88-0b94fb187dfc" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142450
Approved by: https://github.com/jansel
2024-12-18 01:48:04 +00:00
Aaron Orenstein
159b7ad8aa Improve async workers to handle forking for async compile (#142072)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142072
Approved by: https://github.com/masnesral
2024-12-16 21:16:42 +00:00
Tom Ritchford
da67a6a7bb [inductor] Replace set by OrderedSet (#138466)
Uses the set_linter from https://github.com/pytorch/pytorch/pull/138454
and considerable manual editing

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138466
Approved by: https://github.com/eellison
2024-12-13 16:08:45 +00:00
Sam Larsen
692b5e75ed [logging] Add triton_compile_time_us column to dynamo_compile (#142068)
Test Plan: See internal diff [D66799565](https://www.internalfb.com/diff/D66799565)

Differential Revision: [D66799565](https://our.internmc.facebook.com/intern/diff/D66799565)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142068
Approved by: https://github.com/c00w
2024-12-06 16:11:57 +00:00
Sam Larsen
5bc09ac5e9 Remove option for fork-based compile pool (#142001)
Summary: This has been set to "subproc" for a while internally and externally, so we can remove and simplify some of the code. Note that there's no pressing need here -- just that since we've had an internal outage with the legacy "fork" implementation, it doesn't seem helpful to leave it available. But if people aren't in the mood for this sort of cleanup, I won't be offended to abandon it.

Test Plan: CI

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142001
Approved by: https://github.com/eellison, https://github.com/jansel
2024-12-05 17:02:08 +00:00
Colin L. Rice
86f306b15e _inductor: Add dynamo_timed for async_compile.precompile and turn on waitcounters (#141920)

This fixes some review comments from https://github.com/pytorch/pytorch/pull/141379
and gives us another dynamo_timed event for local compilation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141920
Approved by: https://github.com/masnesral
2024-12-04 04:03:46 +00:00
Colin L. Rice
cc98a1b599 _inductor: Add WaitCounter for triton.compile calls. (#141379)
_inductor: Add WaitCounter for async_compile.wait calls.

This will start recording how long these async_compile.wait calls take.

Note that we want to just unify dynamo_timed in the long term.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141379
Approved by: https://github.com/oulgen, https://github.com/masnesral
2024-12-03 22:56:04 +00:00
Colin L. Rice
0989871ac9 pytorch/feature: Record if parallel compile is enabled (#141074)
This gets a bit messy, but this appears to be the best spot to make a true/false decision.

Note that since we're looking at whether or not it's used, if the pool
doesn't warm up within the time it takes for a compile, we will mark the
feature use as false.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141074
Approved by: https://github.com/masnesral
ghstack dependencies: #141059
2024-12-02 19:09:11 +00:00
Sam Larsen
ff17d2b83e [easy][logging] Remove dynamo_timed fwd_only param (#140993)
Summary: It's ignored; remove it

Test Plan: CI

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140993
Approved by: https://github.com/ezyang
2024-11-20 02:31:51 +00:00
eellison
0c7c5d78fa [inductor] add support for TRITON_INTERPRET (#140841)
Was debugging the issue lower in the stack and found this to be helpful / quick enough to add.

Fix for https://github.com/pytorch/pytorch/issues/123956

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140841
Approved by: https://github.com/exclamaforte
2024-11-19 11:24:13 +00:00
Max Podkorytov
ca30704f0b [Inductor][ROCm][CK] Add standalone runner (#139441)
Generate standalone executable to debug and profile CK gemm instances

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139441
Approved by: https://github.com/ColinPeppler
2024-11-07 06:21:27 +00:00
James Wu
f4ee5a243d Add PT2 Compile Events for triton and kernel compilation + load_by_key_path (#139402)
Adds a few more dynamo_timed() to measure triton compilation and load_by_key_path times.

In the case of async compilation with multiple threads, we'll generate a single `kernel_compile` event that occurs when waiting on all the parallel compiles to finish.

In the case where async parallel compilation is disabled (or, compile threads are warming up), we'll generate a `triton_compile` event for each kernel.

The `triton_compile` events are a bit questionable: do we need a row for each triton compile event? It might eat into our already low retention, so I might just remove that. Will discuss with @slarsen.

Differential Revision: [D65215707](https://our.internmc.facebook.com/intern/diff/D65215707/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139402
Approved by: https://github.com/oulgen
2024-11-04 06:37:18 +00:00
Sam Larsen
06b5330674 [easy] Log subproc pool creation (#138642)
Summary: Request from internal to log subproc pool creation

Test Plan:
```
$ TORCH_LOGS=+torch._inductor.async_compile python ~/add.py
I1022 14:12:41.915000 444394 torch/_inductor/async_compile.py:165] Creating subprocess pool with 32 workers
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138642
Approved by: https://github.com/eellison
2024-10-23 02:41:42 +00:00
Max Podkorytov
52ba40c6f6 [ROCm][AOTI] add CK backend (#135641)
Companion to #134379

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135641
Approved by: https://github.com/ColinPeppler, https://github.com/chenyang78

Co-authored-by: Colin Peppler <colinpeppler@meta.com>
2024-10-07 23:53:58 +00:00
Colin Peppler
42adadf2f2 [aotinductor] enable CUTLASS backend (#134379)
### Context
This PR allows CUTLASS kernels usage in AOTI. It does this by:
* For any CUTLASS kernels that win during autotuning, compile them as a .so & .o
* When creating the final model .so, link all the CUTLASS kernels .o files
* Make sure we codegen things correctly (argument dtypes, and extern "C" linkage for the CUTLASS kernel)

### Example
https://gist.github.com/ColinPeppler/e834fa2255c37e9444b6d540bf7bd04d#file-model-cpp-L548-L549

```
TORCH_LOGS="+output_code" python test/inductor/test_cutlass_backend.py -v -k test_max_autotune_cutlass_backend_regular_mm
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134379
Approved by: https://github.com/tenpercent, https://github.com/chenyang78
2024-10-04 17:32:41 +00:00
Sam Larsen
1028cedf71 [inductor] Enable parallel compile by default in fbcode (#136246)
Summary: Now that we have subprocess parallel compile on by default, we can change the internal compile_threads default to > 1 with a killswitch. There's some jankiness so we can avoid evaluating the justknob at import.

Test Plan: Ran codecache tests with JK on, then canaried locally with JK off

Differential Revision: D62913998

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136246
Approved by: https://github.com/eellison
2024-09-24 18:10:01 +00:00
Sam Larsen
bf8d0e3107 [inductor] Enable subprocess parallel compile internally with killswitch (#132467)
Differential Revision: [D60629630](https://our.internmc.facebook.com/intern/diff/D60629630)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132467
Approved by: https://github.com/eellison
2024-09-10 19:05:46 +00:00
Sam Larsen
a2db22e6bb [inductor] Catch BrokenProcessPool and print a more helpful message. (#135120)
Summary: BrokenProcessPool means a parallel-compile subprocess exited, which we never expect. It's likely due to a crash, so print a more meaningful error message and a note that it's probably easier to debug with parallel compile turned off. Output looks like:
```
...
  File "/data/users/slarsen/pytorch/torch/_inductor/runtime/compile_tasks.py", line 45, in _reload_python_module
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/torchinductor_slarsen/4q/c4qw7xk5lbb7whg5txnk4hwbc7z6kepak3o666tr3d64gcad5r5b.py", line 815, in <module>
    async_compile.wait(globals())
  File "/data/users/slarsen/pytorch/torch/_inductor/async_compile.py", line 265, in wait
    raise RuntimeError(
RuntimeError: A compilation subprocess exited unexpectedly. This is likely due to a crash. To facilitate debugging, you can re-run with TORCHINDUCTOR_COMPILE_THREADS=1 to cause compilation to occur in the main process.
```
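
A minimal sketch of the handling described above (the real message and the surrounding wait() logic live in torch/_inductor/async_compile.py):

```python
from concurrent.futures.process import BrokenProcessPool


def wait_on_futures(futures):
    results = []
    for fut in futures:
        try:
            results.append(fut.result())
        except BrokenProcessPool as e:
            raise RuntimeError(
                "A compilation subprocess exited unexpectedly. Re-run with "
                "TORCHINDUCTOR_COMPILE_THREADS=1 to compile in the main process."
            ) from e
    return results
```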

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135120
Approved by: https://github.com/Chillee
2024-09-07 16:33:37 +00:00
Sam Larsen
362ecd9817 [inductor] Skip the sub-process pool until it's ready (#133508)
Summary: Torch-compiling a quick script can be a bit slower than it needs to be: even though we initialize the subprocess pool early, it still might not be ready by the time we try to compile the first Triton kernel. Instead, let's use the single-threaded path until the pool has successfully completed a no-op job.
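
A sketch of the warm-up check (names are assumptions): a no-op job is submitted when the pool is created, and compiles fall back to the in-process path until that job has completed.

```python
from concurrent.futures import Future, ProcessPoolExecutor


def _noop() -> None:
    return None


def compile_in_process(src: str) -> str:
    return f"compiled:{len(src)}"  # stand-in for the single-threaded compile path


class SubprocPoolSketch:
    def __init__(self, workers: int):
        self._pool = ProcessPoolExecutor(max_workers=workers)
        self._warmup: Future = self._pool.submit(_noop)  # readiness probe

    def compile(self, src: str) -> str:
        if not self._warmup.done():
            # the pool hasn't finished its no-op warm-up job yet, so take the
            # in-process path instead of queueing behind a cold pool
            return compile_in_process(src)
        return self._pool.submit(compile_in_process, src).result()
```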

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133508
Approved by: https://github.com/Chillee
2024-09-04 03:26:55 +00:00
Xuehai Pan
b6d477fd56 [BE][Easy][16/19] enforce style for empty lines in import segments in torch/_i*/ (#129768)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129768
Approved by: https://github.com/jansel
2024-07-20 16:20:58 +00:00
Sam Larsen
358da54be5 [inductor] Better messaging when triton version is too old (#130403)
Summary:
If triton is available, but we can't import triton.compiler.compiler.triton_key, then we see some annoying behavior:
1) If we don't actually need to compile triton, the subprocess pool will still spew error messages about the import failure; it's unclear to users if this is an actual problem.
2) If we do need to compile triton, we a) see the error messages from above and b) get a vanilla import exception without the helpful "RuntimeError: Cannot find a working triton installation ..."
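
A sketch of the import probe (error text is illustrative): check for triton.compiler.compiler.triton_key up front, stay quiet if Triton is merely unusable, and raise a clear error only when a Triton compile is actually needed.

```python
def has_usable_triton() -> bool:
    try:
        from triton.compiler.compiler import triton_key  # noqa: F401
    except ImportError:
        return False
    return True


def require_triton() -> None:
    # called only when we actually need to compile a Triton kernel
    if not has_usable_triton():
        raise RuntimeError(
            "Cannot find a working triton installation. Either the package "
            "is not installed or it is too old."
        )
```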

Test Plan: Ran with and without torch.compile for a) recent version of triton, b) triton 2.2, and c) no triton. In all cases, verified expected output (success or meaningful error message)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130403
Approved by: https://github.com/eellison
2024-07-10 23:45:50 +00:00
Sam Larsen
87d14ad419 [inductor] Fix TORCHINDUCTOR_FORCE_DISABLE_CACHES (#129257)
Summary: See https://github.com/pytorch/pytorch/issues/129159; this option wasn't doing its job for a few reasons. In this PR:
* Fix the with_fresh_cache_if_config() decorator
* Reset the "TORCHINDUCTOR_CACHE_DIR" & "TRITON_CACHE_DIR" env vars in the subprocesses to support them changing in the parent process

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129257
Approved by: https://github.com/oulgen
2024-06-26 18:34:48 +00:00
Max Podkorytov
79959d707c [Inductor][ROCm] Composable Kernel backend for Inductor (#125453)
This PR adds an alternative backend for Inductor, bringing Composable Kernel Universal GEMM instances into the autotune instance selection.

The implementation is heavily influenced by the series of PRs which adds CUTLASS backend (https://github.com/pytorch/pytorch/issues/106991). The main differences are
 (1) customizing compiler for the ROCm platform
 (2) customizing template code generation for Composable Kernel Universal GEMM instances.

We provide config tuning knobs for balancing instance-source compilation time against finding the best instance.

### Testing
Install the ck library
```
pip install git+https://github.com/rocm/composable_kernel@develop
```
Run the test
```
TORCH_LOGS=+torch._inductor \
pytest --capture=tee-sys test/inductor/test_ck_backend.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125453
Approved by: https://github.com/eellison, https://github.com/jansel
2024-06-25 20:54:14 +00:00
PyTorch MergeBot
ad76da6c16 Revert "[inductor] Fix TORCHINDUCTOR_FORCE_DISABLE_CACHES (#129257)"
This reverts commit 7b57ddd38c.

Reverted https://github.com/pytorch/pytorch/pull/129257 on behalf of https://github.com/clee2000 due to one of the PRs in the stack seems to have broken test/distributed/_composable/test_replicate_with_compiler.py::ReplicateTest::test_bucketing_concat_op on distributed https://github.com/pytorch/pytorch/actions/runs/9653941844/job/26627760340 4c1e4c5f30, not tested on this PR due to bad TD ([comment](https://github.com/pytorch/pytorch/pull/129257#issuecomment-2189444171))
2024-06-25 16:48:32 +00:00