Commit Graph

159 Commits

Author SHA1 Message Date
Laith Sakka
26f7ca3972 Unify dynamic shapes APIs naming 2 (expect_true and check) attempt2 (#156518)
Summary:
The functions guard_lt, guard_equals, and guard_leq work similarly to torch.check and expect_true, but they operate on SymPy expressions. Notably, guard_equals applies local replacements before comparison, which might be better extracted into a separate function.

This pull request standardizes naming conventions to match symbolic_shapes.py. Specifically:
- it introduces size_vars.expect_true and size_vars.check
- guard_lt becomes check_lt
- guard_leq becomes check_leq
- guard_equals becomes check_equals
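
A minimal before/after sketch of the renaming (the helper and the specific expressions are hypothetical; `sizevars` is assumed to be Inductor's size-variable allocator, e.g. `V.graph.sizevars`, and only the method names come from this PR):

```python
import sympy

def validate_tile(sizevars, m: sympy.Expr, block: int = 64) -> None:
    # New spellings introduced by this PR (old names in comments):
    sizevars.check_lt(sympy.Integer(0), m)                # was guard_lt
    sizevars.check_leq(m, sympy.Integer(8192))            # was guard_leq
    sizevars.check_equals(m % block, sympy.Integer(0))    # was guard_equals
    # New helpers mirroring symbolic_shapes.py naming:
    sizevars.expect_true(sympy.Lt(0, m))
    sizevars.check(sympy.Le(m, 8192))
```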

I am also seeing a couple of incorrect usages, which I will fix in the next PR.

Test Plan:
OSS and cont

Rollback Plan:

Differential Revision: D77054177

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156518
Approved by: https://github.com/bobrenjc93
2025-06-24 21:01:38 +00:00
Paul Zhang
86996c15dc [Inductor] Allow exhaustive autotuning across all GEMM options (#156610)
Differential Revision: D76843916

Exhaustive autotuning is meant to autotune GEMM configs across the entire search space of possible configs. Some of these configs can cause extremely long compilation times and OOMs, especially configs of the following nature:
- Excessive register spillage
- Using much larger amounts of shared memory than is available on the hardware

This diff prunes out those configs to make exhaustive autotuning more viable, along with supporting exhaustive autotuning for the persistent+TMA template and decompose_k. Previously, exhaustive autotuning would hang; now we are able to tune shapes in ~5 minutes. Below is a sample log for autotuning with exhaustive:

```
  AUTOTUNE mm(1152x21504, 21504x1024)
  strides: [21504, 1], [1, 21504]
  dtypes: torch.bfloat16, torch.bfloat16
  mm 0.1167 ms 100.0%
  triton_mm_6270 0.1172 ms 99.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=256, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0
  triton_mm_6522 0.1183 ms 98.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0
  triton_mm_persistent_tma_7482 0.1190 ms 98.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, A_ROW_MAJOR=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, B_ROW_MAJOR=False, EVEN_K=True, GROUP_M=8, NUM_SMS=132, TMA_SIZE=128, USE_FAST_ACCUM=False, num_stages=5, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0
  triton_mm_persistent_tma_7483 0.1195 ms 97.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, A_ROW_MAJOR=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, B_ROW_MAJOR=False, EVEN_K=True, GROUP_M=8, NUM_SMS=132, TMA_SIZE=128, USE_FAST_ACCUM=False, num_stages=5, num_warps=8, num_consumer_groups=0, num_buffers_warp_spec=0
  triton_mm_6523 0.1274 ms 91.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8, num_consumer_groups=0, num_buffers_warp_spec=0
  triton_mm_6267 0.1285 ms 90.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=256, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0
  triton_mm_6519 0.1287 ms 90.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0
  triton_mm_persistent_tma_7480 0.1298 ms 89.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, A_ROW_MAJOR=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, B_ROW_MAJOR=False, EVEN_K=True, GROUP_M=8, NUM_SMS=132, TMA_SIZE=128, USE_FAST_ACCUM=False, num_stages=4, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0
  triton_mm_persistent_tma_7312 0.1302 ms 89.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, A_ROW_MAJOR=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=256, B_ROW_MAJOR=False, EVEN_K=True, GROUP_M=8, NUM_SMS=132, TMA_SIZE=128, USE_FAST_ACCUM=False, num_stages=4, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0
  SingleProcess AUTOTUNE benchmarking takes 298.7185 seconds and 21.2569 seconds precompiling for 2210 choices
  INFO:tritonbench.utils.triton_op:Took 333894.46ms to get benchmark function for pt2_matmul_maxautotune
```
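
For reference, a hedged sketch of how exhaustive GEMM autotuning is typically requested (the config knob names below are assumptions about the Inductor config surface, not something this PR changes):

```python
import torch
import torch._inductor.config as inductor_config

# Assumed knob: widen the GEMM config search space from the default to "EXHAUSTIVE".
inductor_config.max_autotune_gemm_search_space = "EXHAUSTIVE"

@torch.compile(mode="max-autotune")
def f(a, b):
    return a @ b
```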

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156610
Approved by: https://github.com/jansel
2025-06-24 01:42:05 +00:00
Nicolas Macchioni
50940270ae [BE][3/X] Phase out usage of use_max_autotune() (#155849)
See #155847 for context

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155849
Approved by: https://github.com/masnesral
2025-06-17 20:26:29 +00:00
PyTorch MergeBot
503362d019 Revert "Unify dynamic shapes APIs naming 2 (expect_true and check) (#155776)"
This reverts commit 603a54a9b3.

Reverted https://github.com/pytorch/pytorch/pull/155776 on behalf of https://github.com/atalman due to failing internal build ([comment](https://github.com/pytorch/pytorch/pull/155776#issuecomment-2977041192))
2025-06-16 15:13:53 +00:00
Laith Sakka
603a54a9b3 Unify dynamic shapes APIs naming 2 (expect_true and check) (#155776)
The functions guard_lt, guard_equals, and guard_leq work similarly to torch.check and expect_true, but they operate on SymPy expressions. Notably, guard_equals applies local replacements before comparison, which might be better extracted into a separate function.

This pull request standardizes naming conventions to match symbolic_shapes.py. Specifically:
- it introduces size_vars.expect_true and size_vars.check
- guard_lt becomes check_lt
- guard_leq becomes check_leq
- guard_equals becomes check_equals

I am also seeing a couple of incorrect usages, which I will fix in the next PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155776
Approved by: https://github.com/bobrenjc93
ghstack dependencies: #154774
2025-06-14 17:13:53 +00:00
penknife6153
3e38feb05f [inductor] Add configuration control for CUTLASS operation selection. (#155770)
Added a new configuration option `cutlass_enabled_ops` that allows users to control which operations use CUTLASS lowerings. By default, CUTLASS is enabled for all operations (maintaining backward compatibility), but users can now selectively enable it only for specific operations to optimize compilation time.

**Fixes #155718**

## Usage Examples

```bash
# Enable CUTLASS for all operations (default behavior)
export TORCHINDUCTOR_CUTLASS_ENABLED_OPS="ALL"

# Enable CUTLASS only for matrix multiplication operations
export TORCHINDUCTOR_CUTLASS_ENABLED_OPS="mm,addmm"

# Enable CUTLASS only for batch operations
export TORCHINDUCTOR_CUTLASS_ENABLED_OPS="bmm,baddbmm"

# Disable CUTLASS for all operations
export TORCHINDUCTOR_CUTLASS_ENABLED_OPS=""
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155770
Approved by: https://github.com/henrylhtsang
2025-06-14 08:19:54 +00:00
Laith Sakka
f4376cac54 unify symbolic_shapes and sizevars dynamic shapes APIs naming 1 (#154774)
Inductor has a set of APIs that allow performing symbolic evaluations similar to those of symbolic shapes,
but they operate on sympy expressions instead of SymNodes. The naming is not consistent; we are making it
consistent in this stack.

Step 1: unify the statically_know_true naming for a consistent experience.
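
A sketch of the unified spelling (the helper below is hypothetical; only the `statically_known_true` name comes from this unification, and `sizevars` is assumed to be `V.graph.sizevars`):

```python
import sympy

def can_skip_bounds_mask(sizevars, numel: sympy.Expr, block: int) -> bool:
    # True only if the fact can be proven statically (no guard added);
    # otherwise conservatively returns False.
    return sizevars.statically_known_true(sympy.Eq(numel % block, 0))
```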

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154774
Approved by: https://github.com/drisspg, https://github.com/bobrenjc93, https://github.com/eellison
2025-06-12 16:11:55 +00:00
David Berard
c3ecabf059 [inductor][triton pin] add support for new TMA API for mm.py templates (#155723)
Triton 3.4 will remove the experimental TMA APIs: https://github.com/triton-lang/triton/pull/6488

For mm.py templates, this PR adds support for using the new APIs when they are available (and otherwise falls back to the experimental APIs).
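
A minimal sketch of the feature detection this implies (the attribute names are assumptions used only for illustration, not the shim's actual spellings):

```python
import triton.language as tl

def has_new_tma_api() -> bool:
    # New stable TMA API expected in Triton 3.4 (assumed name).
    return hasattr(tl, "make_tensor_descriptor")

def has_experimental_tma_api() -> bool:
    # Older experimental API used as the fallback (assumed name).
    return hasattr(tl, "_experimental_make_tensor_descriptor")
```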

For flex_attention, we'll remove TMA support for Triton 3.2 and 3.3 (versions of triton that don't have the new API).

For mm_scaled_grouped.py, https://github.com/pytorch/pytorch/pull/150944 will remove TMA support for Triton 3.2.

Note: we attempted this earlier with https://github.com/pytorch/pytorch/pull/154858, but this broke TMA usage in Triton 3.2.

Differential Revision: [D76444471](https://our.internmc.facebook.com/intern/diff/D76444471)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155723
Approved by: https://github.com/NikhilAPatel
2025-06-12 06:25:47 +00:00
Oguz Ulgen
d1947a8707 Migrate from lru_cache to cache (#155613)
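For context, `functools.cache` (Python 3.9+) is shorthand for `lru_cache(maxsize=None)`; a representative, hypothetical change looks like:

```python
from functools import cache  # previously: from functools import lru_cache

@cache  # previously: @lru_cache(None)
def triton_version_tuple() -> tuple[int, ...]:
    import triton
    return tuple(int(x) for x in triton.__version__.split(".")[:2])
```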
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155613
Approved by: https://github.com/ezyang
ghstack dependencies: #155612
2025-06-11 19:44:18 +00:00
PyTorch MergeBot
3b7c5e6fa5 Revert "[inductor][triton pin] TMA shim refactor & mm, mm_scaled_grouped support (#155182)"
This reverts commit b07725a951.

Reverted https://github.com/pytorch/pytorch/pull/155182 on behalf of https://github.com/davidberard98 due to fails on triton 3.2 (internally) ([comment](https://github.com/pytorch/pytorch/pull/155182#issuecomment-2960664845))
2025-06-10 21:53:01 +00:00
David Berard
b07725a951 [inductor][triton pin] TMA shim refactor & mm, mm_scaled_grouped support (#155182)
Follow-up to #154858.

Triton 3.4 will provide a different API for TMA compared to Triton 3.3; the TMA shim in triton_helpers dispatches to the correct API.

First, this refactors the TMA shim to drop args that aren't supported from Triton 3.2 to Triton 3.4: in particular, strides (Triton 3.2 version doesn't accept non-contiguous inputs, so we just infer contiguous strides in Triton 3.4) and element_ty (Triton 3.4 doesn't support this arg, so in Triton 3.2 we just infer it from base_ptr).

Second, this updates mm.py & mm_scaled_grouped.py to use the TMA shim.

Differential Revision: [D76318784](https://our.internmc.facebook.com/intern/diff/D76318784)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155182
Approved by: https://github.com/drisspg
2025-06-10 06:48:42 +00:00
Michael Lazos
40d02eb481 [Cutlass] Allow filtering by fast_accum for scaled_mm (#155195)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155195
Approved by: https://github.com/drisspg
ghstack dependencies: #154829, #154835
2025-06-09 22:46:18 +00:00
Max Podkorytov
1e6a653234 [ROCm][Inductor][CK] Split ck and ck-tile inductor backend(s) (#155294)
... and fix ck-tile instances not being generated due to incorrect caching

### Testing

Added test cases for CKTILE instances

```
pytest test/inductor/test_ck_backend.py -k gemm_backends_CKTILE
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155294
Approved by: https://github.com/coconutruben
2025-06-09 20:40:26 +00:00
Laith Sakka
b0a2ca65ef support more prologue functions in generated templates cache (#154892)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154892
Approved by: https://github.com/jansel, https://github.com/eellison
ghstack dependencies: #154891
2025-06-04 23:45:36 +00:00
drisspg
208965a9d6 Fix unbackend symint error (#154672)
## Summary

@laithsakka and I spoke offline about this one; the TL;DR is that we wanted this
![image](https://github.com/user-attachments/assets/2e537612-3261-4fbe-a6b9-f8ff92ba3c37)

to also be true for Inductor. In that vein we added two new APIs to sizevars, `guard_or_false` and `guard_or_true`,
with the following semantics:

guard_or_false, guard_or_true:

These APIs may add guards, but will never fail with data-dependent errors. They try to evaluate the expression, possibly adding guards; if that fails due to data dependency, instead of hard failing, False or True is returned.

When to use this?

- Performance optimizations that warrant a recompilation.
- Take the general path and add a runtime check, as in the example below.
```
# Consider this branching:
if x == 0:
    return 1
else:
    return 10

# To make it data-dependent friendly, it can be written as follows:
if guard_or_false(x == 0):
    return 1
else:
    torch.check(x != 0)  # runtime check
    return 10
```

However, there is still one more API needed to make this example work: a torch.check that works with expressions. I will leave that to @laithsakka.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154672
Approved by: https://github.com/laithsakka
2025-05-30 07:45:01 +00:00
Joaquin
cb56df55dc [Inductor]Cleanup autotune_fallback_to_aten post-deprecation (#154331)
Fixes #153298

This PR is the 3rd and final step of #147479
All references to autotune_fallback_to_aten have been removed, and the feature is now deprecated.
All calls to should_fallback_to_aten() were also removed, as they were deemed unnecessary.

[henrylhtsang](https://github.com/henrylhtsang)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154331
Approved by: https://github.com/henrylhtsang
2025-05-29 20:29:58 +00:00
Michael Lazos
423fc671e9 [Cutlass] Support float8_e4m3fn GEMM (#153890)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153890
Approved by: https://github.com/drisspg, https://github.com/eellison
2025-05-22 08:37:33 +00:00
Laith Sakka
4bcff4af99 Move prologue_supported_inputs computations to def_kernal (#150869)
This avoids replaying load_input on a cache hit in the generate_code_cache.
The idea is that if a template has prologue_loads_all_inputs = True, then
all inputs are loaded, and hence there is no need to replay.

Effect on the current benchmark in a local run on a dev server:
18549985383 -> 15072230073
25697270062 -> 20738613297

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150869
Approved by: https://github.com/eellison
2025-05-22 06:24:44 +00:00
Max Podkorytov
7ef2c62fd3 [ROCm][Inductor][CK] Add ck-tile based universal gemm kernels to torch.mm autotune choices (#152341)
This PR adds code generation for CK-tile based universal gemm kernels to the CK backend for Inductor, and adds these kernels to autotune choices.

Unlike legacy-CK based kernels (which are generated by parsing the CK instances from CK library), we generate the set of instances by manually specifying the tuning parameters.

This PR introduces a new template for code generation, and compilation/autotuning is handled by the existing infrastructure.

Points of discussion:

* For simplicity and reduced coupling with CK, the instance filter checks only data type and layout, and doesn't check the alignment requirement - meaning that more instances will be compiled than necessary - while keeping the code generation independent from internal CK logic which checks the alignment validity at runtime
* CK-tile instances are enabled whenever legacy-CK instances are enabled. A config knob could be introduced to differentiate between the instance types if that's needed
* Whether gemm problem size K is ever dynamic, since whenever it's not a compile-time constant, we need to perform a runtime dispatch between several kernels

**Testing**

Use the existing tests in `test/inductor/test_ck_backend.py`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152341
Approved by: https://github.com/chenyang78
2025-05-21 23:59:16 +00:00
Laith Sakka
11c0ffefcd Cache code generation during triton template expansion and enable it for mm_template. (#151773)
In a model, we see ~40% of the time spent in mm/addmm tuning. The model has 2000 mms,
many of which receive the same input shapes.

With autotune enabled, this becomes expensive. While we already cache autotuning results, we
did not previously cache the generation of the Python code and the loading of each config that we autotune on.

This diff handles the code generation part (template expansions); a previous diff handled the loading part.
This is expected to save 20% on the model I am working on.

How do we do the caching?
For a given configuration and input layout, the generated code is always the same. One caveat is that
some other information collected during code generation is input dependent (namely, it depends on input
names and symbol names in the inputs), not just the layout.
To handle those we use a record-and-replay approach, where we record the functions that are called during
code generation that affect those outputs, and replay them on a cache hit.
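
A toy sketch of the record-and-replay idea (all names here are hypothetical; the real cache records the code-generation hooks whose results depend on input and symbol names):

```python
class RecordingProxy:
    """Wraps an object and records method calls so they can be replayed later."""

    def __init__(self, target):
        self._target = target
        self.tape = []  # list of (method_name, args, kwargs)

    def __getattr__(self, name):
        method = getattr(self._target, name)

        def recorded(*args, **kwargs):
            self.tape.append((name, args, kwargs))
            return method(*args, **kwargs)

        return recorded


def replay(tape, target):
    # On a cache hit we skip regenerating the template code and simply
    # re-run the recorded input-dependent calls against the new object.
    for name, args, kwargs in tape:
        getattr(target, name)(*args, **kwargs)
```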

Effect on the current benchmark on a local run on dev server.
mm_loop: 24115830838 -> 18362098019
mm_loop_dynamic: 30506097176 -> 25697270062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151773
Approved by: https://github.com/eellison
2025-05-21 18:55:41 +00:00
PaulZhang12
63e5d46478 [Inductor] Subgraph support dynamic input expressions (#153754)
Support subgraph choices taking in inputs that have dynamic dimensions. Tested with the decomposeK subgraph decomposition.

Differential Revision: [D74484741](https://our.internmc.facebook.com/intern/diff/D74484741/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153754
Approved by: https://github.com/eellison
2025-05-20 16:07:18 +00:00
PyTorch MergeBot
b15720118a Revert "Cache code generation during triton template expansion and enable it for mm_template. (#151773)"
This reverts commit 9180bb187c.

Reverted https://github.com/pytorch/pytorch/pull/151773 on behalf of https://github.com/malfet due to It broke ROCm, see f9aa3bae8c/1 ([comment](https://github.com/pytorch/pytorch/pull/151773#issuecomment-2892587039))
2025-05-20 00:42:53 +00:00
Laith Sakka
9180bb187c Cache code generation during triton template expansion and enable it for mm_template. (#151773)
In a model, we see ~40% of the time spent in mm/addmm tuning. The model has 2000 mms,
many of which receive the same input shapes.

With autotune enabled, this becomes expensive. While we already cache autotuning results, we
did not previously cache the generation of the Python code and the loading of each config that we autotune on.

This diff handles the code generation part (template expansions); a previous diff handled the loading part.
This is expected to save 20% on the model I am working on.

How do we do the caching?
For a given configuration and input layout, the generated code is always the same. One caveat is that
some other information collected during code generation is input dependent (namely, it depends on input
names and symbol names in the inputs), not just the layout.
To handle those we use a record-and-replay approach, where we record the functions that are called during
code generation that affect those outputs, and replay them on a cache hit.

Effect on the current benchmark on a local run on dev server.
mm_loop: 24115830838 -> 18362098019
mm_loop_dynamic: 30506097176 -> 25697270062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151773
Approved by: https://github.com/eellison
2025-05-19 20:38:04 +00:00
PaulZhang12
dccd19c2ef [Inductor] Construct subgraph with benchmarking args not example_inputs (#153753)
If the inputs to a subgraph have FlexibleLayout, the subgraph does not currently freeze the layouts. Therefore, the `example_inputs` generated might not be consistent in layout with the `args` passed in for benchmarking.

Differential Revision: [D74900879](https://our.internmc.facebook.com/intern/diff/D74900879/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153753
Approved by: https://github.com/eellison
2025-05-19 15:58:40 +00:00
eellison
eaf2dee10e don't run triton mm for k<32 (#153550)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153550
Approved by: https://github.com/suo

Co-authored-by: Natalia Gimelshein <ngimel@meta.com>
2025-05-15 02:36:44 +00:00
drisspg
14f8066910 Ensure mxfp8 scaled_mm works w/ max-autotune (#152744)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152744
Approved by: https://github.com/Skylion007
2025-05-06 01:16:57 +00:00
PaulZhang12
84aa0985fb [Inductor] Add decomposeK as an autotuning choice for mm (#150654)
As a result of adding subgraph as a choice to inductor https://github.com/pytorch/pytorch/pull/149761 and enabling FP32 output from PyTorch GEMMs from FP16/BF16 inputs: https://github.com/pytorch/pytorch/pull/150812, this PR enables decompose_k as an autotuning choice for Inductor in generating the fastest matmuls with Triton. DecomposeK is currently only enabled for `torch.compile`.

Followups:
* decompose_k does not currently support epilogue fusion, which will take some work to enable
* Enable autotuning the bmm with Triton Templates as well without requiring tons of more compile time, async compilation. Anecdotal evidence shows that Triton BMM performs better usually than aten BMM
* Add for addmm
* Enable for Inference and AOTI

Below are the results of running TritonBench for Split-K shapes, comparing aten performance versus pt2_triton, which now autotunes on decompose_k. We see a >10% speedup over aten on average, and for some shapes over 3x the performance of the best previous Triton mm:

![TritonBench Split-K results](https://github.com/user-attachments/assets/27d85bbc-4f3a-43a6-a8fa-d4a5bbb8c999)

TorchInductor Benchmark Dashboard:
![TorchInductor benchmark dashboard](https://github.com/user-attachments/assets/4acd7ffc-407f-4cfd-98bb-2e3d8b1f00b3)

We see speedups across all runs for training. Compile time increased as expected, with more `mm` options to tune over.

Differential Revision: [D73820115](https://our.internmc.facebook.com/intern/diff/D73820115)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150654
Approved by: https://github.com/eellison
2025-05-03 02:23:54 +00:00
PyTorch MergeBot
7c3e679ddd Revert "[Inductor] Add decomposeK as an autotuning choice for mm (#150654)"
This reverts commit fdcfc6a61a.

Reverted https://github.com/pytorch/pytorch/pull/150654 on behalf of https://github.com/wdvr due to Failing ROCM tests: inductor/test_subgraph_choice.py::TestSubgraphChoice::test_subgraph_decompose_k [GH job link](https://github.com/pytorch/pytorch/actions/runs/14786111108/job/41515742446) [HUD commit link](3c54e0c216) ([comment](https://github.com/pytorch/pytorch/pull/150654#issuecomment-2846470409))
2025-05-02 06:31:38 +00:00
PaulZhang12
fdcfc6a61a [Inductor] Add decomposeK as an autotuning choice for mm (#150654)
As a result of adding subgraph as a choice to inductor https://github.com/pytorch/pytorch/pull/149761 and enabling FP32 output from PyTorch GEMMs from FP16/BF16 inputs: https://github.com/pytorch/pytorch/pull/150812, this PR enables decompose_k as an autotuning choice for Inductor in generating the fastest matmuls with Triton. DecomposeK is currently only enabled for `torch.compile`.

Followups:
* decompose_k does not currently support epilogue fusion, which will take some work to enable
* Enable autotuning the bmm with Triton Templates as well without requiring tons of more compile time, async compilation. Anecdotal evidence shows that Triton BMM performs better usually than aten BMM
* Add for addmm
* Enable for Inference and AOTI

Below are the results of running TritonBench for Split-K shapes, comparing aten performance versus pt2_triton, which now autotunes on decompose_k. We see a >10% speedup over aten on average, and for some shapes over 3x the performance of the best previous Triton mm:

![TritonBench Split-K results](https://github.com/user-attachments/assets/27d85bbc-4f3a-43a6-a8fa-d4a5bbb8c999)

TorchInductor Benchmark Dashboard:
![TorchInductor benchmark dashboard](https://github.com/user-attachments/assets/4acd7ffc-407f-4cfd-98bb-2e3d8b1f00b3)

We see speedups across all runs for training. Compile time increased as expected, with more `mm` options to tune over.

Differential Revision: [D73820115](https://our.internmc.facebook.com/intern/diff/D73820115)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150654
Approved by: https://github.com/eellison
2025-05-01 23:01:30 +00:00
henrylhtsang
f2cc07d202 [cutlass backend] Add addmm dynamic support (#152498)
Differential Revision: [D73893133](https://our.internmc.facebook.com/intern/diff/D73893133/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152498
Approved by: https://github.com/ColinPeppler
2025-05-01 01:40:08 +00:00
henrylhtsang
4ea2e093ca [inductor][BE] Clean up use_mixed_mm and mixed_mm_choice usage inside pytorch (#152071)
Differential Revision: [D73551912](https://our.internmc.facebook.com/intern/diff/D73551912/)

Decided to leave the mixed_mm tests alive.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152071
Approved by: https://github.com/eellison
2025-04-25 17:25:55 +00:00
Paul Zhang
2a2ddff214 [Inductor] Fix consolidating _scaled_mm into mm template TMA error (#150686)
Summary: The previous diff broke a few tests that didn't run on internal or GH CI (T220169086); this fixes that issue. The {% if %} block is only supposed to support autotuned parameters (constexpr), and should not be used for locals, based on other examples.

Test Plan: buck test 'fbcode//mode/opt' fbcode//caffe2/test/inductor:fp8 -- --exact 'caffe2/test/inductor:fp8 - test_tensorwise_scaling_bfloat16_shape_16,32,32_has_bias_False_use_fast_accum_True_persistent_matmul_True (caffe2.test.inductor.test_fp8.TestFP8Lowering)'

Reviewed By: NikhilAPatel

Differential Revision: D72460516

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150686
Approved by: https://github.com/eellison, https://github.com/NikhilAPatel
2025-04-04 22:49:22 +00:00
PaulZhang12
e62d958f02 [Inductor] Reland Merge Triton ScaledMM as epilogue to MM template #150045 (#150441)
Merges https://github.com/pytorch/pytorch/pull/150438 and https://github.com/pytorch/pytorch/pull/150045. https://github.com/pytorch/pytorch/pull/150045 was already landed, but was missing a change, which made it unable to land internally.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150441
Approved by: https://github.com/clee2000
2025-04-02 17:49:32 +00:00
Nick Riasanovsky
4934a83347 [AMD] [TRITON] [INDUCTOR] Add tl.assume to enable bufferops on AMD (#150373)
Summary: Update the GEMM template to include the necessary `tl.assume` annotations to enable bufferops with AMD.

Test Plan: Tested manually with a simple matmul run with torch.compile(f, mode="max-autotune") and the environment variables TRITON_ALWAYS_COMPILE=1 AMDGCN_ENABLE_DUMP=1 AMDGCN_USE_BUFFER_OPS=1.
Inspecting the generated AMDGCN, all loads/stores use bufferops.
Note: since Inductor loads constants for many of the shape values, assumes are generally not needed for the stride/shape information, but pid calculations are generally a gap in Triton's inference capability.
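
A minimal sketch of the kind of annotation being added (the helper is hypothetical; `tl.assume` requires a sufficiently recent Triton):

```python
import triton
import triton.language as tl

@triton.jit
def _tile_offsets(pid, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    grid_n = tl.cdiv(N, BLOCK_N)
    pid_m = pid // grid_n
    pid_n = pid % grid_n
    # Hint to the compiler that the derived program ids are non-negative so the
    # address arithmetic can be proven in-bounds and lowered to buffer ops on AMD.
    tl.assume(pid_m >= 0)
    tl.assume(pid_n >= 0)
    return pid_m * BLOCK_M, pid_n * BLOCK_N
```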

Differential Revision: D71922698

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150373
Approved by: https://github.com/eellison
2025-04-01 23:29:39 +00:00
PyTorch MergeBot
f04cf13bdd Revert "Merge Triton ScaledMM as epilogue to MM template (#150045)"
This reverts commit 981048854d.

Reverted https://github.com/pytorch/pytorch/pull/150045 on behalf of https://github.com/PaulZhang12 due to Need to add PR 150415 fixes for internal merge ([comment](https://github.com/pytorch/pytorch/pull/150045#issuecomment-2770252452))
2025-04-01 17:54:28 +00:00
PaulZhang12
981048854d Merge Triton ScaledMM as epilogue to MM template (#150045)
Previously, scaled_mm's (FP8 matmul) Triton lowering for Inductor was in a separate template. This PR consolidates that lowering into the mm template, with an added epilogue to handle multiplying the scales. This paves the way for future scaled variants of BMM and Grouped GEMM in Inductor.

Currently, there is still a separate template for the TMA+persistent version of scaled_mm, mirroring the mm lowering's separate TMA+persistent template. We will hopefully consolidate the extra scaled_mm TMA+persistent template once the mm template consolidation is done.
TODO: consolidate the TMA+persistent logic into one template and remove the separate scaled_mm TMA template.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150045
Approved by: https://github.com/drisspg
2025-03-31 23:20:14 +00:00
Jack Taylor
32299e5f9a Reland "Introduce new template heuristic for triton autotune configs" (#147452)
This change was reverted in https://github.com/pytorch/pytorch/pull/147388 for regressing an internal workload.

I have removed the additional ir.device_type calls in mm_scaled and unpack_mixed_mm.py which could be contributing to the additional compile time.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147452
Approved by: https://github.com/jansel
2025-03-26 15:47:06 +00:00
Andrey Talman
bc88f6faa1 Use TorchVersion for triton version check (#149136)
Follow-up to https://github.com/pytorch/pytorch/pull/149092#issuecomment-2721990321,
to use TorchVersion for Triton version parsing.
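
A hedged sketch of the intended pattern, parsing the Triton version with TorchVersion instead of packaging (the exact call sites are assumptions):

```python
import triton
from torch.torch_version import TorchVersion

# TorchVersion supports PEP 440-style comparisons against plain strings,
# so no separate runtime dependency on `packaging` is needed.
HAS_TRITON_33 = TorchVersion(triton.__version__) >= "3.3.0"
```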

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149136
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-03-18 13:48:46 +00:00
atalman
65d19a5699 Remove runtime dependency on packaging (#149092)
Looks like after https://github.com/pytorch/pytorch/pull/148924,
we are seeing this error in the nightly test:
https://github.com/pytorch/pytorch/actions/runs/13806023728/job/38616861623

```
  File "/Users/runner/work/_temp/anaconda/envs/test_conda_env/lib/python3.13/site-packages/torch/_inductor/pattern_matcher.py", line 79, in <module>
    from .lowering import fallback_node_due_to_unsupported_type
  File "/Users/runner/work/_temp/anaconda/envs/test_conda_env/lib/python3.13/site-packages/torch/_inductor/lowering.py", line 7024, in <module>
    from . import kernel
  File "/Users/runner/work/_temp/anaconda/envs/test_conda_env/lib/python3.13/site-packages/torch/_inductor/kernel/__init__.py", line 1, in <module>
    from . import mm, mm_common, mm_plus_mm
  File "/Users/runner/work/_temp/anaconda/envs/test_conda_env/lib/python3.13/site-packages/torch/_inductor/kernel/mm.py", line 6, in <module>
    from packaging.version import Version
ModuleNotFoundError: No module named 'packaging'
```

Hence, removing the runtime dependency on packaging, since it may not be installed by default.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149092
Approved by: https://github.com/drisspg, https://github.com/davidberard98
2025-03-13 14:53:13 +00:00
David Berard
9ad64ce795 [triton 3.3] Forward-fix mm template selection logic (#148924)
Follow-up from https://github.com/pytorch/pytorch/pull/148662.

The logic from https://github.com/pytorch/pytorch/pull/148662 is incorrect; what we want is "choose the second, AMD-specific template only if we're on hip AND the Triton version is < 3.3". Negating it, the code should be "choose the first template if we're NOT on hip OR the Triton version is >= 3.3".
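
In code form, the corrected condition (names are illustrative only):

```python
def use_default_mm_template(is_hip: bool, triton_version: tuple[int, int]) -> bool:
    # "AMD-specific template only if hip AND Triton < 3.3" negates (De Morgan)
    # to "default template if NOT hip OR Triton >= 3.3".
    return (not is_hip) or triton_version >= (3, 3)
```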

Tested locally to verify that it fixes the test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148924
Approved by: https://github.com/drisspg, https://github.com/atalman, https://github.com/eellison
2025-03-11 09:05:44 +00:00
Sampsa
9f170d9d13 [Triton 3.3] Remove ROCm specific mm gemm template (#148662)
Fixes: https://github.com/pytorch/pytorch/issues/147121
Triton 3.3.x fixes the problem.

This needs to be handled in a non-BC-breaking way, so we will conditionalise this change on the Triton version.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148662
Approved by: https://github.com/davidberard98

Co-authored-by: Jack Taylor <108682042+jataylo@users.noreply.github.com>
2025-03-08 01:24:40 +00:00
Rachel Guo
3f069e7679 [mm_logs] enhance the printing for overview info (#148716)
Summary:
Previously the dynamo counters did not print the counts information automatically.

This explicitly adds a log message, printed after lowering, with overview info for Inductor aten mms.

It will look like the following; the name is in the format `{aten_op_name}_{m}_{n}_{k}`:
```
torch/_inductor/compile_fx.py:832] [0/0] Overview info of inductor aten mms: (aten.addmm_16_6_16: 1), (name: count), xxx
```

Test Plan:
```
TORCH_LOGS="+inductor" buck2 run -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=h100 @//mode/opt fbcode//caffe2/test/inductor:test_aot_inductor -- -r test_addmm_cuda
```

Differential Revision: D70739912

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148716
Approved by: https://github.com/henrylhtsang
2025-03-07 05:23:49 +00:00
Rachel Guo
679e7d257e [mm_logs] follow up to add count info based on shape for inductor aten.mms (#148623)
Summary:
as title.

When `TORCH_LOGS="+inductor"` is enabled, you get logs at the end such as:

stats [('calls_captured', 1), ('unique_graphs', 1)]
inductor [('pattern_matcher_count', 2), ('pattern_matcher_nodes', 2), ('benchmarking.TritonBenchmarker.benchmark_gpu', 2), **(('aten_addmm', (16, 6, 16)), 1)**, ('extern_calls', 1), ('async_compile_cache_miss', 1)]
graph_break []

Test Plan: follow up to add proper logging test.

Differential Revision: D70665104

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148623
Approved by: https://github.com/henrylhtsang
2025-03-06 16:20:04 +00:00
henrylhtsang
b020d166f2 stage 1 of depreate silent fallback of tuning gemm (#147798)
Differential Revision: [D70045778](https://our.internmc.facebook.com/intern/diff/D70045778/)

context:
https://github.com/pytorch/pytorch/issues/147479

For the most part, this should not change the behavior.

For int_mm, I also removed
```
    # TODO: Re-enable eager mode implementation once cuBLAS is fixed
    if use_cutlass or use_triton_template(layout, enable_int32=True):
        choices = []
```
because I think it is unwanted.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147798
Approved by: https://github.com/eellison
2025-03-05 05:15:59 +00:00
Rachel Guo
1673bc7610 [mm_logs][ez] dump tuned mm info at lowering stage (#148363)
Summary:
As title. It would be beneficial for judging e2e perf improvement.

Easy first step to dump mm info at lowering stage.

e.g.

```
fbsource/fbcode/caffe2/torch/_inductor/kernel/mm.py:525] [0/0] Tuned aten.addmm: m=16, n=6, k=16, layout=FixedLayout('cuda:0', torch.float32, size=[16, 6], stride=[6, 1])
```

Next step:

Dump overview info at `post_grad_graph` stage such as
overall count of `aten.mm` in the graph & visualize to a table structure.

Test Plan: by looking very hard in aot inductor bmm and mm UTs.

Differential Revision: D70507880

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148363
Approved by: https://github.com/henrylhtsang
2025-03-05 02:21:27 +00:00
PyTorch MergeBot
1919e0de9a Revert "stage 1 of depreate silent fallback of tuning gemm (#147798)"
This reverts commit 297c00264e.

Reverted https://github.com/pytorch/pytorch/pull/147798 on behalf of https://github.com/wdvr due to failing internal builds, discussed with author ([comment](https://github.com/pytorch/pytorch/pull/147798#issuecomment-2692390551))
2025-03-01 20:04:23 +00:00
henrylhtsang
297c00264e stage 1 of depreate silent fallback of tuning gemm (#147798)
Differential Revision: [D70045778](https://our.internmc.facebook.com/intern/diff/D70045778/)

context:
https://github.com/pytorch/pytorch/issues/147479

For the most part, this should not change the behavior.

For int_mm, I also removed
```
    # TODO: Re-enable eager mode implementation once cuBLAS is fixed
    if use_cutlass or use_triton_template(layout, enable_int32=True):
        choices = []
```
because I think it is unwanted.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147798
Approved by: https://github.com/eellison
2025-02-28 19:51:55 +00:00
eellison
4b7604ec10 Delete Mixed MM Special Casing (#147151)
Now that torchinductor supports prologue fusion, we can delete all the mixed mm code. When I benchmarked int8 weight-only mm in the new path against int8 mm in the old path in the [following benchmark](https://gist.github.com/eellison/46e321709572c11c077d0612cb3492b7), I got a 1.244x geomean speedup across Huggingface linear shapes with bias. There are a couple of reasons for the speedup:

- Prologue fusion is often unprofitable, even for int8 mm. Because the current mixed mm benchmarking only compares triton_int8_mm vs (dtype_conversion + cublas), we miss out on scenarios where the triton template is profitable but the prologue fusion is not.
- Similarly, we miss out on potential epilogue fusions like bias if we dispatch to the [fallback mixed mm](5006932cbc/torch/_inductor/kernel/mm.py (L750-L751)) that mixed_mm will dispatch to, instead of the deferred epilogue tuning in the current path.

It's possible some of the speedups would be smaller on larger models where the epilogue might get fused into a following kernel. Nonetheless, even if this is perf-neutral, it is worth landing for code deduplication.

The one kernel that is a little special and would not fall out of prologue fusion is the uint4x2_mixed_mm kernel. It's still possible to generate it with prologue fusion, though not currently exactly as the current [impl](bd370c138a/torch/_inductor/kernel/unpack_mixed_mm.py (L43-L49)). But the current impl does not compare against a cublas baseline, and I found that it makes things slower (35% slower on a not-particularly-big 1024x1024x1024 mm shape on h100), so this should be fine to delete.

Future optimizations could include:

- cutlass prologue path
- Making prologue fusion support the persistent TMA-based mm template. From @drisspg's experience, this led to nice wins with fp8 but not as nice wins with bf16 mm; I think, similarly, lower-memory-bandwidth int8 mm would benefit.

Differential Revision: [D70114858](https://our.internmc.facebook.com/intern/diff/D70114858)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147151
Approved by: https://github.com/drisspg, https://github.com/cpuhrsch
2025-02-25 04:29:54 +00:00
PyTorch MergeBot
3409cbd177 Revert "Delete Mixed MM Special Casing (#147151)"
This reverts commit d6bb1d7f0a.

Reverted https://github.com/pytorch/pytorch/pull/147151 on behalf of https://github.com/jeanschmidt due to Broke a few internal signals, see comments on D69994157 ([comment](https://github.com/pytorch/pytorch/pull/147151#issuecomment-2676312215))
2025-02-22 17:14:32 +00:00
henrylhtsang
76ce194b8e For addmm and bmm, check if config.autotune_fallback_to_aten before using aten as a fallback. Also fix bmm cutlass backend (#147148)
This PR also fixes BMM, which was silently failing for a while.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147148
Approved by: https://github.com/eellison
2025-02-21 18:41:52 +00:00