Commit Graph

521 Commits

Author SHA1 Message Date
Marcin Pioch
e694280d12 Custom FX pass for inductor's backend registration (#154841)
This PR is related to RFC #153532. It is an extension to Inductor's backend registration interface to allow to register custom FX passes by the backend.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154841
Approved by: https://github.com/jansel

Co-authored-by: Jason Ansel <jansel@jansel.net>
2025-06-06 06:49:44 +00:00
PyTorch MergeBot
5e03433443 Revert "Inductor logging + analysis of torch.profile (#149697)"
This reverts commit e5afbe3124.

Reverted https://github.com/pytorch/pytorch/pull/149697 on behalf of https://github.com/malfet due to Broke rocm, see 642687af29/1 ([comment](https://github.com/pytorch/pytorch/pull/149697#issuecomment-2942415600))
2025-06-05 01:38:13 +00:00
Gabriel Ferns
e5afbe3124 Inductor logging + analysis of torch.profile (#149697)
Prereqs:
 - https://github.com/pytorch/pytorch/pull/152708

Features:
1. Adds inductor's estimate of flops and bandwidth to the json trace events that perfetto uses.
1. Only use the tflops estimation from triton if we don't have the info from the datasheet because Triton's estimates are inaccurate. I have a backlog item to fix triton flops estimation upstream. New `DeviceInfo` class, and new function `get_device_tflops`.
1. New helpers `countable_fx` and `count_flops_fx` helps get the flops of an `fx.Node`.
1. Extends Triton `torch.profiler` logging to `DebugAutotuner`.
1. New script `profile_analysis.py`: `--augment_trace` adds perf estimates to any perfetto json trace, `--analyze` creates a summary table of these perf estimates, and `--diff` will compare two traces side by side:
```python
Device(NVIDIA H100, 0):
 Kernel Name                              | resnet Kernel Count | resnet FLOPS       | resnet bw gbps        | resnet Dur (ms)    | resnet Achieved FLOPS % | resnet Achieved Bandwidth % | newresnet Kernel Count | newresnet FLOPS    | newresnet bw gbps     | newresnet Dur (ms) | newresnet Achieved FLOPS % | newresnet Achieved Bandwidth %
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
 triton_poi_fused__native_batch_norm_legi | 24                  | 0                  | 0.11395268248131513   | 2.5919166666666666 | 0                       | 0.003401572611382541        | 24                     | 0                  | 0.11395268248131513   | 2.5919166666666666 | 0                          | 0.003401572611382541
 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 142                 | 16932673552.422373 | 0.2585007824198784    | 12.441619718309857 | 0.08683422334575583     | 0.007716441266265022        | 142                    | 16932673552.422373 | 0.2585007824198784    | 12.441619718309857 | 0.08683422334575583        | 0.007716441266265022
 triton_red_fused__native_batch_norm_legi | 39                  | 0                  | 0.13990024992108846   | 5.752589743589743  | 0                       | 0.004176126863316074        | 39                     | 0                  | 0.13990024992108846   | 5.752589743589743  | 0                          | 0.004176126863316074
 triton_poi_fused__native_batch_norm_legi | 25                  | 0                  | 0.31824055917536503   | 2.5291999999999994 | 0                       | 0.009499718184339253        | 25                     | 0                  | 0.31824055917536503   | 2.5291999999999994 | 0                          | 0.009499718184339253
 void cutlass::Kernel2<cutlass_80_tensoro | 98                  | 16211056473.596165 | 0.42972434051025826   | 7.130408163265306  | 0.08313362294151874     | 0.012827592254037562        | 98                     | 16211056473.596165 | 0.42972434051025826   | 7.130408163265306  | 0.08313362294151874        | 0.012827592254037562
 triton_red_fused__native_batch_norm_legi | 73                  | 0                  | 0.3225381327611705    | 9.987068493150682  | 0                       | 0.009628003963020014        | 73                     | 0                  | 0.3225381327611705    | 9.987068493150682  | 0                          | 0.009628003963020014
 triton_poi_fused__native_batch_norm_legi | 15                  | 0                  | 1.4491211346487216    | 4.439333333333333  | 0                       | 0.043257347302946926        | 15                     | 0                  | 1.4491211346487216    | 4.439333333333333  | 0                          | 0.043257347302946926
 void cutlass::Kernel2<cutlass_80_tensoro | 186                 | 14501701145.337954 | 0.2667131401910989    | 7.873865591397849  | 0.07436769818122027     | 0.007961586274361157        | 186                    | 14501701145.337954 | 0.2667131401910989    | 7.873865591397849  | 0.07436769818122027        | 0.007961586274361157
 triton_poi_fused__native_batch_norm_legi | 33                  | 0                  | 1.4924556538193923    | 4.3101515151515155 | 0                       | 0.044550915039384846        | 33                     | 0                  | 1.4924556538193923    | 4.3101515151515155 | 0                          | 0.044550915039384846
 triton_red_fused__native_batch_norm_legi | 29                  | 0                  | 0.25562590522631107   | 6.296275862068965  | 0                       | 0.007630624036606301        | 29                     | 0                  | 0.25562590522631107   | 6.296275862068965  | 0                          | 0.007630624036606301
 triton_poi_fused__native_batch_norm_legi | 13                  | 0                  | 0.5870562174192726    | 2.7397692307692307 | 0                       | 0.01752406619162008         | 13                     | 0                  | 0.5870562174192726    | 2.7397692307692307 | 0                          | 0.01752406619162008
 triton_poi_fused__native_batch_norm_legi | 34                  | 0                  | 0.41409928846284      | 2.853588235294117  | 0                       | 0.012361172789935523        | 34                     | 0                  | 0.41409928846284      | 2.853588235294117  | 0                          | 0.012361172789935523
 triton_per_fused__native_batch_norm_legi | 34                  | 0                  | 0.11705315007018151   | 3.460647058823529  | 0                       | 0.0034941238826919864       | 34                     | 0                  | 0.11705315007018151   | 3.460647058823529  | 0                          | 0.0034941238826919864
 triton_poi_fused__native_batch_norm_legi | 16                  | 0                  | 0.17207853197124584   | 2.3459375000000002 | 0                       | 0.005136672596156592        | 16                     | 0                  | 0.17207853197124584   | 2.3459375000000002 | 0                          | 0.005136672596156592
 triton_per_fused__native_batch_norm_legi | 30                  | 0                  | 0.2639714322022256    | 6.131199999999999  | 0                       | 0.007879744244842555        | 30                     | 0                  | 0.2639714322022256    | 6.131199999999999  | 0                          | 0.007879744244842555
 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 100                 | 11875430356.891787 | 0.19494470869421385   | 16.36534           | 0.06089964285585531     | 0.005819245035648175        | 100                    | 11875430356.891787 | 0.19494470869421385   | 16.36534           | 0.06089964285585531        | 0.005819245035648175
 triton_poi_fused__native_batch_norm_legi | 8                   | 0                  | 0.9854096626224687    | 3.2757500000000004 | 0                       | 0.029415213809625928        | 8                      | 0                  | 0.9854096626224687    | 3.2757500000000004 | 0                          | 0.029415213809625928
 void cublasLt::splitKreduce_kernel<32, 1 | 56                  | 34377923395.147064 | 0.8310300045762317    | 3.4199999999999986 | 0.17629704305203628     | 0.024806865808245714        | 56                     | 34377923395.147064 | 0.8310300045762317    | 3.4199999999999986 | 0.17629704305203628        | 0.024806865808245714
 triton_poi_fused__native_batch_norm_legi | 23                  | 0                  | 0.9944002965861103    | 3.2431304347826084 | 0                       | 0.02968359094286896         | 23                     | 0                  | 0.9944002965861103    | 3.2431304347826084 | 0                          | 0.02968359094286896
 triton_per_fused__native_batch_norm_legi | 10                  | 0                  | 0.1826801058931057    | 4.428800000000001  | 0                       | 0.00545313748934644         | 10                     | 0                  | 0.1826801058931057    | 4.428800000000001  | 0                          | 0.00545313748934644
 triton_poi_fused__native_batch_norm_legi | 10                  | 0                  | 0.3168973585366449    | 2.5471999999999997 | 0                       | 0.009459622642884923        | 10                     | 0                  | 0.3168973585366449    | 2.5471999999999997 | 0                          | 0.009459622642884923
 triton_poi_fused__native_batch_norm_legi | 34                  | 0                  | 1.1463614897015777    | 4.124323529411764  | 0                       | 0.03421974596124114         | 34                     | 0                  | 1.1463614897015777    | 4.124323529411764  | 0                          | 0.03421974596124114
 void cask_plugin_cudnn::xmma_cudnn::init | 44                  | 44045510816.64277  | 2.0661232850348643    | 3.6887499999999993 | 0.22587441444432194     | 0.06167532194133924         | 44                     | 44045510816.64277  | 2.0661232850348643    | 3.6887499999999993 | 0.22587441444432194        | 0.06167532194133924
 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 95                  | 7876855400.165316  | 0.4694941555946739    | 18.224315789473682 | 0.04039413025725802     | 0.014014750913273854        | 95                     | 7876855400.165316  | 0.4694941555946739    | 18.224315789473682 | 0.04039413025725802        | 0.014014750913273854
 triton_per_fused__native_batch_norm_legi | 41                  | 0                  | 0.06825669875995298   | 3.0384146341463416 | 0                       | 0.002037513395819492        | 41                     | 0                  | 0.06825669875995298   | 3.0384146341463416 | 0                          | 0.002037513395819492
 triton_poi_fused__native_batch_norm_legi | 23                  | 0                  | 0.08808154712430301   | 2.3275652173913044 | 0                       | 0.0026292999141582997       | 23                     | 0                  | 0.08808154712430301   | 2.3275652173913044 | 0                          | 0.0026292999141582997
 triton_per_fused__native_batch_norm_legi | 40                  | 0                  | 0.18179321034952417   | 4.556825           | 0                       | 0.005426662995508183        | 40                     | 0                  | 0.18179321034952417   | 4.556825           | 0                          | 0.005426662995508183
 triton_poi_fused__native_batch_norm_legi | 15                  | 0                  | 0.5887415155454232    | 2.783866666666667  | 0                       | 0.017574373598370836        | 15                     | 0                  | 0.5887415155454232    | 2.783866666666667  | 0                          | 0.017574373598370836
 void cutlass::Kernel2<cutlass_80_tensoro | 38                  | 14242013806.264643 | 0.256592404353939     | 7.217631578947369  | 0.0730359682372546      | 0.007659474756834           | 38                     | 14242013806.264643 | 0.256592404353939     | 7.217631578947369  | 0.0730359682372546         | 0.007659474756834
 triton_poi_fused__native_batch_norm_legi | 21                  | 0                  | 0.5842860973430516    | 2.7779047619047623 | 0                       | 0.017441376040091088        | 21                     | 0                  | 0.5842860973430516    | 2.7779047619047623 | 0                          | 0.017441376040091088
 triton_per_fused__native_batch_norm_legi | 16                  | 0                  | 0.11509365173486417   | 3.5959375000000002 | 0                       | 0.0034356313950705724       | 16                     | 0                  | 0.11509365173486417   | 3.5959375000000002 | 0                          | 0.0034356313950705724
 triton_poi_fused__native_batch_norm_legi | 14                  | 0                  | 0.1704672000243914    | 2.4044285714285714 | 0                       | 0.00508857313505646         | 14                     | 0                  | 0.1704672000243914    | 2.4044285714285714 | 0                          | 0.00508857313505646
 triton_poi_fused__native_batch_norm_legi | 58                  | 0                  | 2.307520779930795     | 8.190706896551722  | 0                       | 0.06888121731136704         | 58                     | 0                  | 2.307520779930795     | 8.190706896551722  | 0                          | 0.06888121731136704
 triton_per_fused__native_batch_norm_legi | 29                  | 0                  | 0.037243248971881276  | 3.0277586206896556 | 0                       | 0.001111738775280038        | 29                     | 0                  | 0.037243248971881276  | 3.0277586206896556 | 0                          | 0.001111738775280038
 triton_poi_fused__native_batch_norm_legi | 20                  | 0                  | 0.04741699795428918   | 2.2911500000000005 | 0                       | 0.0014154327747549007       | 20                     | 0                  | 0.04741699795428918   | 2.2911500000000005 | 0                          | 0.0014154327747549007
 triton_per_fused__native_batch_norm_legi | 25                  | 0                  | 0.13357016893727824   | 3.37536            | 0                       | 0.003987169222008305        | 25                     | 0                  | 0.13357016893727824   | 3.37536            | 0                          | 0.003987169222008305
 triton_poi_fused__native_batch_norm_legi | 13                  | 0                  | 0.3089862268300253    | 2.8111538461538457 | 0                       | 0.009223469457612694        | 13                     | 0                  | 0.3089862268300253    | 2.8111538461538457 | 0                          | 0.009223469457612694
 triton_poi_fused__native_batch_norm_legi | 17                  | 0                  | 0.3129385387909844    | 2.673              | 0                       | 0.009341448919133863        | 17                     | 0                  | 0.3129385387909844    | 2.673              | 0                          | 0.009341448919133863
 triton_per_fused__native_batch_norm_legi | 19                  | 0                  | 0.2215568162533158    | 3.8837368421052636 | 0                       | 0.0066136363060691275       | 19                     | 0                  | 0.2215568162533158    | 3.8837368421052636 | 0                          | 0.0066136363060691275
 std::enable_if<!(false), void>::type int | 23                  | 504916805.19297093 | 1.0118296096314707    | 8.113913043478261  | 0.0025893169497075447   | 0.030203868944223014        | 23                     | 504916805.19297093 | 1.0118296096314707    | 8.113913043478261  | 0.0025893169497075447      | 0.030203868944223014
 triton_poi_fused_add_copy__38            | 56                  | 0                  | 0                     | 2.132482142857143  | 0                       | 0                           | 56                     | 0                  | 0                     | 2.132482142857143  | 0                          | 0
 triton_poi_fused_convolution_0           | 18                  | 0                  | 0.43458610794936897   | 2.773333333333334  | 0                       | 0.012972719640279667        | 18                     | 0                  | 0.43458610794936897   | 2.773333333333334  | 0                          | 0.012972719640279667
 triton_poi_fused_convolution_1           | 17                  | 0                  | 0.028816312469162712  | 2.6145882352941174 | 0                       | 0.0008601884319153051       | 17                     | 0                  | 0.028816312469162712  | 2.6145882352941174 | 0                          | 0.0008601884319153051
 void convolve_common_engine_float_NHWC<f | 44                  | 8641868995.31118   | 0.024730540008465626  | 25.87327272727273  | 0.04431727689903169     | 0.0007382250748795709       | 44                     | 8641868995.31118   | 0.024730540008465626  | 25.87327272727273  | 0.04431727689903169        | 0.0007382250748795709
 triton_per_fused__native_batch_norm_legi | 12                  | 0                  | 0.6809930918986744    | 4.82675            | 0                       | 0.020328151996975356        | 12                     | 0                  | 0.6809930918986744    | 4.82675            | 0                          | 0.020328151996975356
 triton_per_fused__native_batch_norm_legi | 14                  | 0                  | 0.02883030597936608   | 2.6651428571428575 | 0                       | 0.0008606061486377935       | 14                     | 0                  | 0.02883030597936608   | 2.6651428571428575 | 0                          | 0.0008606061486377935
 triton_per_fused__native_batch_norm_legi | 16                  | 0                  | 0.0014658988233201874 | 2.098              | 0                       | 4.375817383045335e-05       | 16                     | 0                  | 0.0014658988233201874 | 2.098              | 0                          | 4.375817383045335e-05
 triton_poi_fused__native_batch_norm_legi | 13                  | 0                  | 0.9926297180284697    | 3.2367692307692306 | 0                       | 0.02963073785159611         | 13                     | 0                  | 0.9926297180284697    | 3.2367692307692306 | 0                          | 0.02963073785159611
 triton_poi_fused__native_batch_norm_legi | 9                   | 0                  | 1.3008817095666507    | 3.0863333333333336 | 0                       | 0.03883228983781048         | 9                      | 0                  | 1.3008817095666507    | 3.0863333333333336 | 0                          | 0.03883228983781048
 void at::native::(anonymous namespace):: | 98                  | 0                  | 0.09174335613709389   | 4.408520408163265  | 0                       | 0.0027386076458833994       | 98                     | 0                  | 0.09174335613709389   | 4.408520408163265  | 0                          | 0.0027386076458833994
 void at::native::vectorized_elementwise_ | 7                   | 0                  | 0                     | 1.7278571428571428 | 0                       | 0                           | 7                      | 0                  | 0                     | 1.7278571428571428 | 0                          | 0
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149697
Approved by: https://github.com/eellison, https://github.com/shunting314
2025-06-04 20:03:46 +00:00
Boyuan Feng
a4da1d4a47 [Graph Partition] support standalone_compile (#154698)
For graph partition, `write_get_raw_stream_header_once` is done once so the autotune code may not have the header. This PR additionally calls `write_get_raw_stream_header` in `codegen_device_guard_enter` before `get_raw_stream` is used.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154698
Approved by: https://github.com/oulgen
2025-06-03 07:40:42 +00:00
Paul Zhang
0c6c7780d9 [Inductor] Add envvar to disable decomposeK (#154421)
Summary: Add envvar to Inductor config to disable decomposeK autotuning choice

Test Plan: `buck test 'fbcode//mode/opt' fbcode//caffe2/test/inductor:max_autotune -- --exact 'caffe2/test/inductor:max_autotune - test_max_autotune_decompose_k_dynamic_False_sizes2 (caffe2.test.inductor.test_max_autotune.TestMaxAutotune)' --run-disabled`

Reviewed By: eellison

Differential Revision: D75174823

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154421
Approved by: https://github.com/eellison
2025-05-29 23:34:41 +00:00
eellison
d6e29bf875 Reflect back mutation if we clone misaligned tensors (#154442)
Fix for https://github.com/pytorch/pytorch/issues/152425

inductor specializes whether or not a tensor is 16-bit aligned on the first invocation. then, on subsequent invocations, if we inferred alignment but are passed a non-aligned tensor we clone the tensor.

If we infer alignment, then run with unaligned, and mutate the input, we need to reflect back the mutation to the input. This pr adds back that mutation.

We could have also been less aggressive about inferring alignment for mutated tensors, but that has a pretty perf hit.See the following benchmark:
```
import torch

t = torch.rand(4096 * 4096, device="cuda", dtype=torch.float16)

@torch.compile(dynamic=False)
def foo(x):
    return x.add_(1)

import triton

print(triton.testing.do_bench(lambda: foo(t[:-1])))
torch._dynamo.reset()
print(triton.testing.do_bench(lambda: foo(t[1:])))
```
gives
```
0.04063070610165596
0.07613472988113162
```
So almost twice as slow for non-aligned tensors. Tensors changing alignment is a relatively rare case.

In the future, we could considering a multi-kernel approach, or codegening a triton kernel that does most of the loads with aligned instructions, and a prologue/epilogue of un-alignment. But, it's yet to be seen this is a huge issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154442
Approved by: https://github.com/bobrenjc93, https://github.com/bdhirsh
2025-05-29 13:36:48 +00:00
angelayi
26471fc203 [aoti] Initial Metal support (#153959)
An example generated file: P1816629015

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153959
Approved by: https://github.com/malfet, https://github.com/desertfire
ghstack dependencies: #153964
2025-05-23 05:45:35 +00:00
PyTorch MergeBot
47a01f3efb Revert "[aoti] Initial Metal support (#153959)"
This reverts commit 28bcd9eb30.

Reverted https://github.com/pytorch/pytorch/pull/153959 on behalf of https://github.com/angelayi due to previous PR broke frl build ([comment](https://github.com/pytorch/pytorch/pull/153959#issuecomment-2901825315))
2025-05-22 16:17:07 +00:00
angelayi
28bcd9eb30 [aoti] Initial Metal support (#153959)
An example generated file: P1816629015

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153959
Approved by: https://github.com/malfet, https://github.com/desertfire
ghstack dependencies: #153964
2025-05-21 21:55:59 +00:00
PyTorch MergeBot
01bb249978 Revert "has_triton: Use the device interface for detecting Triton availability (#139171)"
This reverts commit 48bfe9afc7.

Reverted https://github.com/pytorch/pytorch/pull/139171 on behalf of https://github.com/masnesral due to Performance regression for huggingface ([comment](https://github.com/pytorch/pytorch/pull/139171#issuecomment-2868939790))
2025-05-10 14:46:23 +00:00
Menglu Yu
2d25e4d478 [1/n][Optimus][Auto-AC] Support activation quantization without scaling (#148380)
Summary: We enable the activation quantization in the forward pass, and users can customize the dtype they want to quantize.

Test Plan:
# unit test

```
buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/inductor:quantization -- test_activation_quantization_aten
```

Buck UI: https://www.internalfb.com/buck2/776d3911-bb86-4ac8-a527-540cf1510b9d
Test UI: https://www.internalfb.com/intern/testinfra/testrun/4785074873051017
Network: Up: 4.3MiB  Down: 42MiB  (reSessionID-fef7e727-68b1-4645-a519-5652854df38d)
Executing actions. Remaining     0/4                                                                                 6.7s exec time total
Command: test.     Finished 2 local
Time elapsed: 3:11.5s
Tests finished: Pass 2. Fail 0. Fatal 0. Skip 0. Build failure 0

# E2E

### how to enable (you can overrite the dtype, if nothing given, the default is fp8)

```
post_grad_fusion_options={
            "activation_quantization_aten_pass": {"quant_type": "torch.float8_e5m2"}
        },
```

Differential Revision: D70522237

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148380
Approved by: https://github.com/Mingming-Ding, https://github.com/Hahu803
2025-05-08 04:44:15 +00:00
George White
48bfe9afc7 has_triton: Use the device interface for detecting Triton availability (#139171)
This PR replaces the `has_triton()` global method which was previously used for this task.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139171
Approved by: https://github.com/jansel, https://github.com/shink
2025-05-07 12:23:10 +00:00
Blaine Burton Rister
bc11afd41f [Inductor] FX backend via Wrapper IR (#146942)
# Sub-PRs

These PRs contain refactors from the main one. They should be reviewed and merged first.

- https://github.com/pytorch/pytorch/pull/150458
- https://github.com/pytorch/pytorch/pull/152391
- https://github.com/pytorch/pytorch/pull/152587

# Feature

The goals of this PR are twofold.

## Goal 1: Introduce Wrapper IR as an intermediate step in wrapper codegen.

In addition to Triton/C++/Halide kernels, Inductor also generates "wrapper" code which allocates memory and calls the kernels. Originally, this wrapper code was fairly standard Python which resembled a user-written PyTorch program. Over time, various wrapper code generators have been added to accommodate things like AOTInductor, which prefers C++ code for static compilation. This complexity has bled into other parts of the codebase, as we now need if/else statements to choose between Python and C++ macros. (See an example [here](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/ir.py#L5515-L5522).) Since most of these code generation steps are conceptually identical across target languages, it seems reasonable to refactor them into some kind of intermediate representation which can be shared between the various backends. This might also make it easier to develop out-of-tree backends which cannot put their own macros in core Inductor components.

This PR takes some initial steps to formalize Inductor's wrapper codegen by generalizing the existing Memory Planning IR into a fully fledged Wrapper IR. This is pretty much identical to the existing Memory Planning IR, but it supports a richer set of ops for things like kernel definitions and calls. This refactor could help encapsulate wrapper codegen. Ideally, we don't need to worry about direct Python/C++ codegen in the main compiler files such as `ir.py`, and can instead defer these to classes like `PythonWrapperCodegen` and `CppWrapperCpu`, which operate on the Wrapper IR.

## Goal 2: Convert Wrapper IR into FX IR.

One of the main benefits of Wrapper IR is to enable more diverse Inductor backends. This PR introduces a converter from Wrapper IR into [FX IR](https://pytorch.org/docs/stable/fx.html), which is the intermediate representation most commonly used in PyTorch graph compilers. The purpose of this is to enable out-of-tree backends to consume Inductor's output in FX IR, which would hopefully make Inductor easier to leverage in novel compilers, hardware accelerators, etc.

It's not trivial to generate Python or C++ code which Inductor can compile and run, and doing so may require changes to other core Inductor files, for the reasons outlined in the previous section. The goal of supporting FX output is to enable something like `torch.compile`'s [custom backend](https://pytorch.org/docs/stable/torch.compiler_custom_backends.html) system, in which an out-of-tree backend can receive an optimized FX graph from Inductor, and compile and run it however it likes.

The typical users of this feature would likely not be part of PyTorch, and may or may not support running a kernel in eager mode. However, they can understand what `torch.empty_strided` means, compile and run Triton kernels, etc. So we just need to present them with an FX graph saying what code Inductor wants to run, which should be easier to analyze and transform in a third party system than Python or C++ source.

Since FX IR is fairly stable, this mechanism should hopefully isolate third-party backends, hardware accelerators, etc. from the implementation details of Inductor, and vice versa.

# Current status

Things that seem to work:

- Converted a lot of the most common Python codegen lines to Wrapper IR lines.
     - Handled the following cases, in addition to what was already in the Memory Planning IR:
         - Comments
         - Triton kernels
         - Extern/fallback kernels
         - Freeing tensors (`del buf0`)
         - MultiOutput
         - Graph outputs
         - ReinterpretView / StorageBox, for both call args and outputs.
     - FX conversion asserts that the program only contains Wrapper IR lines, and not strings of Python/C++ code.
- Prototype FX converter which can handle some of the most common use cases.
   - Defining Triton kernels, and putting them in a side table using TorchDynamo's existing [utilities](https://dev-discuss.pytorch.org/t/higher-order-operators-2023-10/1565).
   - Calling wrapped Triton kernels.
   - Calling extern kernels and certain types of fallback kernels.
       - Support both `extern_kernels.*` and `aten.*`.
       - Support multi-output kernels like `torch.topk`.
   - Graphs with multiple inputs/outputs.
   - Training i.e. calling `Tensor.backward()` in a compiled function.
   - Graph breaks (training).
- Run the `torch.fx.GraphModule` on GPU using the standard `__call__` method. This makes it easy to test the correctness of FX codegen.

Things that don't work:
- Both Wrapper IR and Wrapper -> FX coverage are currently best effort. There are still features which aren't captured as Wrapper IR lines, and fall back to plain strings. This representation is functionally correct but probably not rich enough to achieve the goals outlined in the previous sections.
         - Fallback kernels seem like the most difficult thing to fully cover, since they each define their own Python/C++ macros that would need to be converted to FX.
         - Size/alignment asserts are currently disabled via the config file. It's possible to generate FX IR for these, but it seems reasonable to defer these sanity checks to a later PR.
         - CommBuffer's and distributed communication are not yet supported. An earlier version of this PR attempted to implement this by calling `empty_strided_p2p`. However, building and testing distributed support seems non-trivial, so it's probably better to defer this.

# Out-of-tree compilers

With this PR, out of tree backends will be able to do further compilation on the FX graphs by subclassing `WrapperFxCodegen` and overriding the `compile_graph` function. This follows the same API as torch.compile's [custom backends](https://pytorch.org/docs/stable/torch.compiler_custom_backends.html), where the user simply returns a callable running the graph. The callable need not be a method of `GraphModule` or any other PyTorch class. See an example below.

```
from torch._inductor.codegen.wrapper_fxir import WrapperFxCodegen

class MyCustomBackend(WrapperFxCodegen):
     def compile_graph(self, gm):
         # Add 1 to the graph's outputs
         def compiled_fn(*args):
             return [x + 1 for x in gm.graph.forward(*args)]
         return compiled_fn
```

# Example FX graphs

This section contains some example FX graphs generated by Inductor. The correctness of these graphs was verified against eager mode by calling the corresponding `GraphModule`.

Here's an FX graph calling a basic Triton kernel. Notice how outputs are allocated with `torch.empty_strided`, and the Triton kernel is called by reference to Dynamo's triton side table.
```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((8,), (1,)), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(8,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg1_1, in_ptr1: %arg0_1, out_ptr0: %buf0, xnumel: 8, XBLOCK: 8}})
    return (buf0,)
```

Here's a more complicated graph that calls a `torch.addmm` extern kernel.

```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=2] = placeholder[target=arg1_1]
    %buf0 : [num_users=3] = call_function[target=torch.empty_strided](args = ((), ()), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(1,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg1_1, out_ptr0: %buf0, xnumel: 1, r0_numel: 129, XBLOCK: 1}})
    %buf2 : [num_users=2] = call_function[target=torch.empty_strided](args = ((129, 1), (1, 1)), kwargs = {dtype: torch.float32, device: cuda:0})
    %addmm : [num_users=0] = call_function[target=torch.addmm](args = (%buf0, %arg0_1, %arg1_1), kwargs = {alpha: 1, beta: 1, out: %buf2})
    %delete : [num_users=0] = call_function[target=torch._inductor.codegen.wrapper_fxir.delete](args = (%buf0,), kwargs = {})
    return (buf2,)
```

Here's a graph which indexes into a tuple using `operator.getitem`. This is necessary to use the output of the `torch.topk` operation.

```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %buf0 : [num_users=3] = call_function[target=torch.ops.aten.topk.default](args = (%arg0_1, 2), kwargs = {})
    %buf1 : [num_users=2] = call_function[target=operator.getitem](args = (%buf0, 0), kwargs = {})
    %buf2 : [num_users=2] = call_function[target=operator.getitem](args = (%buf0, 1), kwargs = {})
    %delete : [num_users=0] = call_function[target=torch._inductor.codegen.wrapper_fxir.delete](args = (%buf0,), kwargs = {})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(2,)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf1, xnumel: 2, XBLOCK: 2}})
    %triton_kernel_wrapper_mutation_1 : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 1, constant_args_idx: 1, grid: [(2,)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf2, xnumel: 2, XBLOCK: 2}})
    return (buf1, buf2)
```

Here's a graph that reinterprets an output tensor using `torch.as_strided`. This is one way to handle Inductor's `ReinterpretView` op.

```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((2, 4), (4, 1)), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(8,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg0_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: 8, XBLOCK: 8}})
    %buf0_view_buf0_0 : [num_users=1] = call_function[target=torch.as_strided](args = (%buf0, (8,), (1,), 0), kwargs = {})
    return (buf0_view_buf0_0,)
```

Here's a graph with dynamic shapes. This one is a little bit funky. Inductor provides a graph input for each shape symbol, which we map to a placeholder, in this example `s6`. Then, shape expressions in the generated code can refer to the symbol `s6`. The size hint for `s6` is stored in `node.meta["val"]` where `node` is the placeholder defining it. This works out in the generated python code because the placeholder defines a Python variable with the name `s6`.
```
graph():
    %s6 : [num_users=0] = placeholder[target=s6]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((s6,), (1,)), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [[-(((-s6)//8)), 1, 1]], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg2_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: s6, XBLOCK: 8}})
    return buf0
```

Here's another graph, this time with dynamic shapes and strides. The grid expression is more complex since the numel is a product of dimensions.
```
graph():
    %s10 : [num_users=0] = placeholder[target=s10]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ([s10, s10], [s10, 1]), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [[-(((s10**2)//(-64))), 1, 1]], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg2_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: s10**2, XBLOCK: 64}})
    return buf0
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146942
Approved by: https://github.com/jansel
2025-05-06 10:06:39 +00:00
Mandar Deshpande
e3064bf0e3 [inductor] Allow num_program specification for TMA workspace (#152844)
Summary:
Allow TMA workspace creation allow specification for `num_programs`, which defaults to `num_sms` when not specified.

We need a total `num_programs * num_tma_descriptors` no. of descriptors for a kernel.

Test Plan: CI.

Differential Revision: D74189599

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152844
Approved by: https://github.com/drisspg
2025-05-05 23:02:55 +00:00
PyTorch MergeBot
99dac7005f Revert "[Inductor] FX backend via Wrapper IR (#146942)"
This reverts commit a7691140a0.

Reverted https://github.com/pytorch/pytorch/pull/146942 on behalf of https://github.com/malfet due to Looks like it indeed breaks lint, see a7691140a0/1 ([comment](https://github.com/pytorch/pytorch/pull/146942#issuecomment-2852192778))
2025-05-05 20:01:29 +00:00
Blaine Burton Rister
a7691140a0 [Inductor] FX backend via Wrapper IR (#146942)
# Sub-PRs

These PRs contain refactors from the main one. They should be reviewed and merged first.

- https://github.com/pytorch/pytorch/pull/150458
- https://github.com/pytorch/pytorch/pull/152391
- https://github.com/pytorch/pytorch/pull/152587

# Feature

The goals of this PR are twofold.

## Goal 1: Introduce Wrapper IR as an intermediate step in wrapper codegen.

In addition to Triton/C++/Halide kernels, Inductor also generates "wrapper" code which allocates memory and calls the kernels. Originally, this wrapper code was fairly standard Python which resembled a user-written PyTorch program. Over time, various wrapper code generators have been added to accommodate things like AOTInductor, which prefers C++ code for static compilation. This complexity has bled into other parts of the codebase, as we now need if/else statements to choose between Python and C++ macros. (See an example [here](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/ir.py#L5515-L5522).) Since most of these code generation steps are conceptually identical across target languages, it seems reasonable to refactor them into some kind of intermediate representation which can be shared between the various backends. This might also make it easier to develop out-of-tree backends which cannot put their own macros in core Inductor components.

This PR takes some initial steps to formalize Inductor's wrapper codegen by generalizing the existing Memory Planning IR into a fully fledged Wrapper IR. This is pretty much identical to the existing Memory Planning IR, but it supports a richer set of ops for things like kernel definitions and calls. This refactor could help encapsulate wrapper codegen. Ideally, we don't need to worry about direct Python/C++ codegen in the main compiler files such as `ir.py`, and can instead defer these to classes like `PythonWrapperCodegen` and `CppWrapperCpu`, which operate on the Wrapper IR.

## Goal 2: Convert Wrapper IR into FX IR.

One of the main benefits of Wrapper IR is to enable more diverse Inductor backends. This PR introduces a converter from Wrapper IR into [FX IR](https://pytorch.org/docs/stable/fx.html), which is the intermediate representation most commonly used in PyTorch graph compilers. The purpose of this is to enable out-of-tree backends to consume Inductor's output in FX IR, which would hopefully make Inductor easier to leverage in novel compilers, hardware accelerators, etc.

It's not trivial to generate Python or C++ code which Inductor can compile and run, and doing so may require changes to other core Inductor files, for the reasons outlined in the previous section. The goal of supporting FX output is to enable something like `torch.compile`'s [custom backend](https://pytorch.org/docs/stable/torch.compiler_custom_backends.html) system, in which an out-of-tree backend can receive an optimized FX graph from Inductor, and compile and run it however it likes.

The typical users of this feature would likely not be part of PyTorch, and may or may not support running a kernel in eager mode. However, they can understand what `torch.empty_strided` means, compile and run Triton kernels, etc. So we just need to present them with an FX graph saying what code Inductor wants to run, which should be easier to analyze and transform in a third party system than Python or C++ source.

Since FX IR is fairly stable, this mechanism should hopefully isolate third-party backends, hardware accelerators, etc. from the implementation details of Inductor, and vice versa.

# Current status

Things that seem to work:

- Converted a lot of the most common Python codegen lines to Wrapper IR lines.
     - Handled the following cases, in addition to what was already in the Memory Planning IR:
         - Comments
         - Triton kernels
         - Extern/fallback kernels
         - Freeing tensors (`del buf0`)
         - MultiOutput
         - Graph outputs
         - ReinterpretView / StorageBox, for both call args and outputs.
     - FX conversion asserts that the program only contains Wrapper IR lines, and not strings of Python/C++ code.
- Prototype FX converter which can handle some of the most common use cases.
   - Defining Triton kernels, and putting them in a side table using TorchDynamo's existing [utilities](https://dev-discuss.pytorch.org/t/higher-order-operators-2023-10/1565).
   - Calling wrapped Triton kernels.
   - Calling extern kernels and certain types of fallback kernels.
       - Support both `extern_kernels.*` and `aten.*`.
       - Support multi-output kernels like `torch.topk`.
   - Graphs with multiple inputs/outputs.
   - Training i.e. calling `Tensor.backward()` in a compiled function.
   - Graph breaks (training).
- Run the `torch.fx.GraphModule` on GPU using the standard `__call__` method. This makes it easy to test the correctness of FX codegen.

Things that don't work:
- Both Wrapper IR and Wrapper -> FX coverage are currently best effort. There are still features which aren't captured as Wrapper IR lines, and fall back to plain strings. This representation is functionally correct but probably not rich enough to achieve the goals outlined in the previous sections.
         - Fallback kernels seem like the most difficult thing to fully cover, since they each define their own Python/C++ macros that would need to be converted to FX.
         - Size/alignment asserts are currently disabled via the config file. It's possible to generate FX IR for these, but it seems reasonable to defer these sanity checks to a later PR.
         - CommBuffer's and distributed communication are not yet supported. An earlier version of this PR attempted to implement this by calling `empty_strided_p2p`. However, building and testing distributed support seems non-trivial, so it's probably better to defer this.

# Out-of-tree compilers

With this PR, out of tree backends will be able to do further compilation on the FX graphs by subclassing `WrapperFxCodegen` and overriding the `compile_graph` function. This follows the same API as torch.compile's [custom backends](https://pytorch.org/docs/stable/torch.compiler_custom_backends.html), where the user simply returns a callable running the graph. The callable need not be a method of `GraphModule` or any other PyTorch class. See an example below.

```
from torch._inductor.codegen.wrapper_fxir import WrapperFxCodegen

class MyCustomBackend(WrapperFxCodegen):
     def compile_graph(self, gm):
         # Add 1 to the graph's outputs
         def compiled_fn(*args):
             return [x + 1 for x in gm.graph.forward(*args)]
         return compiled_fn
```

# Example FX graphs

This section contains some example FX graphs generated by Inductor. The correctness of these graphs was verified against eager mode by calling the corresponding `GraphModule`.

Here's an FX graph calling a basic Triton kernel. Notice how outputs are allocated with `torch.empty_strided`, and the Triton kernel is called by reference to Dynamo's triton side table.
```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((8,), (1,)), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(8,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg1_1, in_ptr1: %arg0_1, out_ptr0: %buf0, xnumel: 8, XBLOCK: 8}})
    return (buf0,)
```

Here's a more complicated graph that calls a `torch.addmm` extern kernel.

```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=2] = placeholder[target=arg1_1]
    %buf0 : [num_users=3] = call_function[target=torch.empty_strided](args = ((), ()), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(1,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg1_1, out_ptr0: %buf0, xnumel: 1, r0_numel: 129, XBLOCK: 1}})
    %buf2 : [num_users=2] = call_function[target=torch.empty_strided](args = ((129, 1), (1, 1)), kwargs = {dtype: torch.float32, device: cuda:0})
    %addmm : [num_users=0] = call_function[target=torch.addmm](args = (%buf0, %arg0_1, %arg1_1), kwargs = {alpha: 1, beta: 1, out: %buf2})
    %delete : [num_users=0] = call_function[target=torch._inductor.codegen.wrapper_fxir.delete](args = (%buf0,), kwargs = {})
    return (buf2,)
```

Here's a graph which indexes into a tuple using `operator.getitem`. This is necessary to use the output of the `torch.topk` operation.

```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %buf0 : [num_users=3] = call_function[target=torch.ops.aten.topk.default](args = (%arg0_1, 2), kwargs = {})
    %buf1 : [num_users=2] = call_function[target=operator.getitem](args = (%buf0, 0), kwargs = {})
    %buf2 : [num_users=2] = call_function[target=operator.getitem](args = (%buf0, 1), kwargs = {})
    %delete : [num_users=0] = call_function[target=torch._inductor.codegen.wrapper_fxir.delete](args = (%buf0,), kwargs = {})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(2,)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf1, xnumel: 2, XBLOCK: 2}})
    %triton_kernel_wrapper_mutation_1 : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 1, constant_args_idx: 1, grid: [(2,)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf2, xnumel: 2, XBLOCK: 2}})
    return (buf1, buf2)
```

Here's a graph that reinterprets an output tensor using `torch.as_strided`. This is one way to handle Inductor's `ReinterpretView` op.

```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((2, 4), (4, 1)), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(8,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg0_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: 8, XBLOCK: 8}})
    %buf0_view_buf0_0 : [num_users=1] = call_function[target=torch.as_strided](args = (%buf0, (8,), (1,), 0), kwargs = {})
    return (buf0_view_buf0_0,)
```

Here's a graph with dynamic shapes. This one is a little bit funky. Inductor provides a graph input for each shape symbol, which we map to a placeholder, in this example `s6`. Then, shape expressions in the generated code can refer to the symbol `s6`. The size hint for `s6` is stored in `node.meta["val"]` where `node` is the placeholder defining it. This works out in the generated python code because the placeholder defines a Python variable with the name `s6`.
```
graph():
    %s6 : [num_users=0] = placeholder[target=s6]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((s6,), (1,)), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [[-(((-s6)//8)), 1, 1]], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg2_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: s6, XBLOCK: 8}})
    return buf0
```

Here's another graph, this time with dynamic shapes and strides. The grid expression is more complex since the numel is a product of dimensions.
```
graph():
    %s10 : [num_users=0] = placeholder[target=s10]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ([s10, s10], [s10, 1]), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [[-(((s10**2)//(-64))), 1, 1]], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg2_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: s10**2, XBLOCK: 64}})
    return buf0
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146942
Approved by: https://github.com/jansel
2025-05-05 19:34:49 +00:00
PaulZhang12
84aa0985fb [Inductor] Add decomposeK as an autotuning choice for mm (#150654)
As a result of adding subgraph as a choice to inductor https://github.com/pytorch/pytorch/pull/149761 and enabling FP32 output from PyTorch GEMMs from FP16/BF16 inputs: https://github.com/pytorch/pytorch/pull/150812, this PR enables decompose_k as an autotuning choice for Inductor in generating the fastest matmuls with Triton. DecomposeK is currently only enabled for `torch.compile`.

Followups:
* decompose_k does not currently support epilogue fusion, which will take some work to enable
* Enable autotuning the bmm with Triton Templates as well without requiring tons of more compile time, async compilation. Anecdotal evidence shows that Triton BMM performs better usually than aten BMM
* Add for addmm
* Enable for Inference and AOTI

Below are the results of running TritonBench for Split-K shapes, comparing the aten performance versus pt2_triton, which now autotunes on decompose_k, seeing >10% speedup compared to aten on average, and for some shapes over 3x the performance of the best Triton mm previously:

<img width="929" alt="Screenshot 2025-04-28 at 9 15 39 PM" src="https://github.com/user-attachments/assets/27d85bbc-4f3a-43a6-a8fa-d4a5bbb8c999" />

TorchInductor Benchmark Dashboard:
<img width="1727" alt="Screenshot 2025-04-30 at 2 02 53 PM" src="https://github.com/user-attachments/assets/4acd7ffc-407f-4cfd-98bb-2e3d8b1f00b3" />

We see speedups across all runs for training. Compile time increased as expected, with more `mm` options to tune over.

Differential Revision: [D73820115](https://our.internmc.facebook.com/intern/diff/D73820115)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150654
Approved by: https://github.com/eellison
2025-05-03 02:23:54 +00:00
PyTorch MergeBot
7c3e679ddd Revert "[Inductor] Add decomposeK as an autotuning choice for mm (#150654)"
This reverts commit fdcfc6a61a.

Reverted https://github.com/pytorch/pytorch/pull/150654 on behalf of https://github.com/wdvr due to Failing ROCM tests: inductor/test_subgraph_choice.py::TestSubgraphChoice::test_subgraph_decompose_k [GH job link](https://github.com/pytorch/pytorch/actions/runs/14786111108/job/41515742446) [HUD commit link](3c54e0c216) ([comment](https://github.com/pytorch/pytorch/pull/150654#issuecomment-2846470409))
2025-05-02 06:31:38 +00:00
PaulZhang12
fdcfc6a61a [Inductor] Add decomposeK as an autotuning choice for mm (#150654)
As a result of adding subgraph as a choice to inductor https://github.com/pytorch/pytorch/pull/149761 and enabling FP32 output from PyTorch GEMMs from FP16/BF16 inputs: https://github.com/pytorch/pytorch/pull/150812, this PR enables decompose_k as an autotuning choice for Inductor in generating the fastest matmuls with Triton. DecomposeK is currently only enabled for `torch.compile`.

Followups:
* decompose_k does not currently support epilogue fusion, which will take some work to enable
* Enable autotuning the bmm with Triton Templates as well without requiring tons of more compile time, async compilation. Anecdotal evidence shows that Triton BMM performs better usually than aten BMM
* Add for addmm
* Enable for Inference and AOTI

Below are the results of running TritonBench for Split-K shapes, comparing the aten performance versus pt2_triton, which now autotunes on decompose_k, seeing >10% speedup compared to aten on average, and for some shapes over 3x the performance of the best Triton mm previously:

<img width="929" alt="Screenshot 2025-04-28 at 9 15 39 PM" src="https://github.com/user-attachments/assets/27d85bbc-4f3a-43a6-a8fa-d4a5bbb8c999" />

TorchInductor Benchmark Dashboard:
<img width="1727" alt="Screenshot 2025-04-30 at 2 02 53 PM" src="https://github.com/user-attachments/assets/4acd7ffc-407f-4cfd-98bb-2e3d8b1f00b3" />

We see speedups across all runs for training. Compile time increased as expected, with more `mm` options to tune over.

Differential Revision: [D73820115](https://our.internmc.facebook.com/intern/diff/D73820115)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150654
Approved by: https://github.com/eellison
2025-05-01 23:01:30 +00:00
Michael Lazos
a1f6d85b36 [Cutlass] Fixes for e2e compilation in arg rendering (#151405)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151405
Approved by: https://github.com/eellison
ghstack dependencies: #152305, #152306, #150905
2025-04-29 23:06:01 +00:00
Aaron Orenstein
c8b4a39d73 Add precedence to the infix printing done by sympy_str. (#151920)
Add precedence to the infix printing done by sympy_str.

Without this change sympy_str will print the same string for both `a+b*(c+d)` and `(a+b)*(c+d)`.

While there I also cleaned up the printing for `-a` and `a - b`.

Added some tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151920
Approved by: https://github.com/jansel
2025-04-29 00:58:58 +00:00
Anthony Shoumikhin
e2f9759bd0 Fix broken URLs (#152237)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152237
Approved by: https://github.com/huydhn, https://github.com/malfet
2025-04-27 09:56:42 +00:00
PyTorch MergeBot
72f711e200 Revert "[inductor] Change minimum number of SMs to 60 to let Ada use Triton GEMM backend (#150888)"
This reverts commit 8d81806211.

Reverted https://github.com/pytorch/pytorch/pull/150888 on behalf of https://github.com/henrylhtsang due to Revert because this change isn't needed ([comment](https://github.com/pytorch/pytorch/pull/150888#issuecomment-2822768377))
2025-04-23 00:26:49 +00:00
Rachel Guo
c729f7dbee [provenance_tracking][reland] Fix UT error and re-land ExternKernel support (#151709)
Summary:
ATT.

reverted previous diff :  D72572050

Test Plan:
```
 TORCH_LOGS="+inductor, output_code" buck2 run -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=h100 @//mode/opt fbcode//caffe2/test/inductor:provenance_tracing -- -r test_triton_kernel_to_post_grad_tracing_extern_kernel
```

Differential Revision: D73281217

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151709
Approved by: https://github.com/jingsh
2025-04-22 15:44:56 +00:00
rzou
29317f8585 [standalone_compile] Some misc fixes (#151502)
This PR fixes two things.

The first problem is that in the vLLM style standalone_compile is
called from within a custom torch.compile backend. If there already is a
FakeTensorMode (which there is), we shouldn't create a new
FakeTensorMode with the same shape_env, instead we should just reuse the
same FakeTensorMode.

The second thing is that compile_fx can mutate the passed in gm, so we
deepcopy (since standalone_compile should be standalone)

Test Plan:
- new test
- updated old tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151502
Approved by: https://github.com/oulgen
ghstack dependencies: #151501, #151551
2025-04-18 12:34:13 +00:00
eellison
6d46b530fc Remove libdevice ops in inductor (#151562)
Now that we track dtypes during codegen, we can delete all these extra ops that worked around the problem by doing dispatch at lowering time.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151562
Approved by: https://github.com/isuruf, https://github.com/jansel
2025-04-17 22:18:00 +00:00
Chong Gu
a05cc9f494 Remove Clear Cache Time from do_bench_using_profiling (#150696)
Summary: In most instances, this action would take ~33% of the total run time, which means that our benchmark would previously differ from the end results by a lot.

Test Plan:
We can compare the benchmark results for
```
CUDA_VISIBLE_DEVICES=4,5 buck run mode/opt -c python.package_style=inplace -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=h100a //caffe2/torch/fb/model_transform/experimental/benchmark:mts_gpu_benchmark -- --model-snapshot-id=672308665_0 --lower-backend=AOT_INDUCTOR --node-replacement-dict="{'torch.nn.Linear':{'(autotune)': 'fp8_float_model_dynamic_quantization_rowwise'}}" --trace-aot-inductor-module=True --disable-acc-tracer=False --batch-size=1024
```
before and after the diff, and notice that on average, the benchmark results decrease by ~0.1ms per iteration, which is more closely aligned with the lowered modules.

Differential Revision: D72469845

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150696
Approved by: https://github.com/frank-wei
2025-04-17 07:25:41 +00:00
henrylhtsang
532025fbd0 [cutlass backend][ez] Ban FP32 output dtype from using CUTLASS GEMM backend (#151279)
FP32 not supported: https://github.com/pytorch/pytorch/issues/145952

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151279
Approved by: https://github.com/ColinPeppler
2025-04-16 01:12:18 +00:00
Oguz Ulgen
3cf0e2d8ec Add inductor standalone_compile API (#150670)
This PR adds standalone_compile API that does precompilation via caching to support vLLM use case in the short term while we work on the longer term precompilation solution.

```
standalone_compile(gm, example_inputs, options) -> CompiledArtifact
CompiledArtifact.save(path, format: binary|unpacked = binary)
CompiledArtifact.load(path, format: binary|unpacked = binary)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150670
Approved by: https://github.com/jamesjwu, https://github.com/zou3519
2025-04-15 23:38:15 +00:00
PyTorch MergeBot
74f6bc28a7 Revert "Add inductor standalone_compile API (#150670)"
This reverts commit c9aef50898.

Reverted https://github.com/pytorch/pytorch/pull/150670 on behalf of https://github.com/Camyll due to breaking internal builds with torch module not found error ([comment](https://github.com/pytorch/pytorch/pull/150670#issuecomment-2806975267))
2025-04-15 17:35:59 +00:00
Oguz Ulgen
c9aef50898 Add inductor standalone_compile API (#150670)
This PR adds standalone_compile API that does precompilation via caching to support vLLM use case in the short term while we work on the longer term precompilation solution.

```
standalone_compile(gm, example_inputs, options) -> CompiledArtifact
CompiledArtifact.save(path, format: binary|unpacked = binary)
CompiledArtifact.load(path, format: binary|unpacked = binary)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150670
Approved by: https://github.com/jamesjwu, https://github.com/zou3519
2025-04-14 22:00:09 +00:00
PyTorch MergeBot
24b3ab9255 Revert "Add inductor standalone_compile API (#150670)"
This reverts commit bbc5fe8504.

Reverted https://github.com/pytorch/pytorch/pull/150670 on behalf of https://github.com/albanD due to Broke profiler test ([comment](https://github.com/pytorch/pytorch/pull/150670#issuecomment-2802067144))
2025-04-14 15:22:33 +00:00
Oguz Ulgen
bbc5fe8504 Add inductor standalone_compile API (#150670)
This PR adds standalone_compile API that does precompilation via caching to support vLLM use case in the short term while we work on the longer term precompilation solution.

```
standalone_compile(gm, example_inputs, options) -> CompiledArtifact
CompiledArtifact.save(path, format: binary|unpacked = binary)
CompiledArtifact.load(path, format: binary|unpacked = binary)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150670
Approved by: https://github.com/jamesjwu, https://github.com/zou3519
2025-04-14 07:07:10 +00:00
Thomas Adams
8494d5582a Propagate callable parameter types using ParamSpec (#142306) (#151014)
Partially addresses #142306

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151014
Approved by: https://github.com/Skylion007
2025-04-13 20:38:11 +00:00
Michael Lazos
fe961679d5 [Inductor] add support for disabling atomic adds (#151033)
As title

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151033
Approved by: https://github.com/eellison, https://github.com/shunting314
2025-04-11 18:41:56 +00:00
henrylhtsang
8d81806211 [inductor] Change minimum number of SMs to 60 to let Ada use Triton GEMM backend (#150888)
context: https://github.com/pytorch/pytorch/issues/150390#issuecomment-2790272814

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150888
Approved by: https://github.com/jansel
2025-04-10 22:10:55 +00:00
PyTorch MergeBot
e786b3bf54 Revert "[inductor] Change minimum number of SMs to 60 to let Ada use Triton GEMM backend (#150888)"
This reverts commit 115a165f9b.

Reverted https://github.com/pytorch/pytorch/pull/150888 on behalf of https://github.com/malfet due to This indeed broke all those inductor tests ([comment](https://github.com/pytorch/pytorch/pull/150888#issuecomment-2795231901))
2025-04-10 21:46:23 +00:00
henrylhtsang
115a165f9b [inductor] Change minimum number of SMs to 60 to let Ada use Triton GEMM backend (#150888)
context: https://github.com/pytorch/pytorch/issues/150390#issuecomment-2790272814

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150888
Approved by: https://github.com/jansel
2025-04-10 19:46:35 +00:00
PaulZhang12
e62d958f02 [Inductor] Reland Merge Triton ScaledMM as epilogue to MM template #150045 (#150441)
Merges https://github.com/pytorch/pytorch/pull/150438 and https://github.com/pytorch/pytorch/pull/150045. https://github.com/pytorch/pytorch/pull/150045 was already landed, but did not include a change that makes it unable to land internally.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150441
Approved by: https://github.com/clee2000
2025-04-02 17:49:32 +00:00
PyTorch MergeBot
f04cf13bdd Revert "Merge Triton ScaledMM as epilogue to MM template (#150045)"
This reverts commit 981048854d.

Reverted https://github.com/pytorch/pytorch/pull/150045 on behalf of https://github.com/PaulZhang12 due to Need to add PR 150415 fixes for internal merge ([comment](https://github.com/pytorch/pytorch/pull/150045#issuecomment-2770252452))
2025-04-01 17:54:28 +00:00
PaulZhang12
981048854d Merge Triton ScaledMM as epilogue to MM template (#150045)
Previously, scaled_mm's (FP8 matmul) Triton lowering for inductor was in a separate template. This PR consolidates that lowering into the mm template, with an added epilogue to deal with multiplying the scales. This paves the way for future scaled variants of BMM, Grouped GEMM in inductor.

Currently, there is still a separate template for TMA+persistent version of scaled_mm. The current mm lowering has a separate template for TMA + Persistent version. Will hopefully consolidate the extra scaled_mm TMA+persistent template when the consolidation for the mm template is done.
TODO: Consolidate TMA+Persistent logic into 1 template and remove separate scaled_mm TMA template

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150045
Approved by: https://github.com/drisspg
2025-03-31 23:20:14 +00:00
Sam Larsen
266bd22b44 Improve subproc autotuning implementation (#149700)
Summary: The primary change is to update the autotune-in-a-subproc implementation to avoid using multiprocessing spawn. Spawn (re)executes the toplevel script in the subproc, which can be problematic. The approach here is similar to Triton parallel compile: we Popen a subproc on a controlled entry point and communicate over pipes. That change drove a lot of refactoring in the TuningProcess class, so I took the opportunity to simplify some things, rename some methods, etc.

One other notable change is around the timeout / kill approach. After a timeout, we were previously attempting to stop the subproc in three steps (graceful shutdown, sigkill if graceful fails, sigterm if sigkill fails). I'm gonna argue think that's not useful: 1) The graceful shutdown is never going to work unless the subproc happens to have just completed its task and is ready to receive the next command. 2) If we're going to kill the subproc, let's just take the most aggressive approach and move on as quickly as possible to restarting it rather than waiting to see if previous shutdown attempts succeeded. The only downside that I can find find is maybe a little log spew?, e.g., ` ResourceWarning: subprocess 2987680 is still running`

List of changes:
* Use Popen instead of spawn for the autotuning subprocess.
* Introduced a new entry point `__autotune_main__.py`
* Renamed some TuningProcess methods. For example `shutdown` makes more sense than `terminate` because the latter implies a forced kill.
* Simplified the implementation around benchmarking timeout and how we kill the subproc after a timeout.
* Deprecated the unused timeout configs in `_inductor/config.py`
* Moved `get_ld_library_path` helper to a common utils file.
* Added more unit tests for subproc crashes / timeouts / exceptions, etc.

Test plan:
* New unit tests
* Also ran internally with all combinations of: build mode `opt` and `dev-nosan`, and `buck run` vs. executing the `.par` file directly.
* Made sure the functionality to parallelize autotuning across different GPUs is working (it wasn't clear to me this was behaving the way we wanted it to).

Differential Revision: [D71976971](https://our.internmc.facebook.com/intern/diff/D71976971)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149700
Approved by: https://github.com/aorenste, https://github.com/jansel, https://github.com/eellison
2025-03-28 01:06:39 +00:00
Boyuan Feng
c830d750e6 [graph partition] support splitting on custom ops (#149782)
This PR adds support for graph partition on custom ops. Land after #149458.

### API
This PR provides a new API to register/unregister custom ops for graph partition.

```python
def register_custom_op_support_cudagraph(
    operator: torch._library.custom_ops.CustomOpDef,
    is_cudagraphable: bool,
) -> None
```

Example usage:

```python
from torch._inductor.utils import register_custom_op_partition

@torch.library.custom_op("mylib::movement", mutates_args=())
def movement(pic: torch.Tensor) -> torch.Tensor:
    img = pic.cpu()
    cropped_img = (img + 1) * 2
    return cropped_img.cuda() / 255.0

@movement.register_fake
def _(pic):
    return torch.empty_like(pic)

register_custom_op_support_cudagraph(movement, is_cudagraphable=False)
```

### Example
In this example, 1 torch-compiled region has 3 cudagraphs after splitting on 2 custom ops.

![image](https://github.com/user-attachments/assets/6d07355b-6690-4cde-89ef-e4aff6b0079c)

Code to repro:
```python
import torch
from torch._inductor.utils import register_custom_op_support_cudagraph

torch._inductor.config.graph_partition = True

@torch.library.custom_op("mylib::movement", mutates_args=())
def movement(pic: torch.Tensor) -> torch.Tensor:
    img = pic.cpu()
    cropped_img = (img + 1)*2
    return cropped_img.cuda() / 255.

@movement.register_fake
def _(pic):
    return torch.empty_like(pic)

@torch.library.custom_op("mylib::modify", mutates_args=())
def modify(pic: torch.Tensor) -> torch.Tensor:
    pic1 = pic + 1
    pic1_cpu = (pic1.cpu() + 1) * 2
    return pic1_cpu.cuda() + pic

@modify.register_fake
def _(pic):
    return torch.empty_like(pic)

@torch.library.custom_op("mylib::transform", mutates_args=())
def transform(pic: torch.Tensor) -> torch.Tensor:
    return (pic + 1) * 2

@transform.register_fake
def _(pic):
    return torch.empty_like(pic)

register_custom_op_support_cudagraph(movement, is_cudagraphable=False)
register_custom_op_support_cudagraph(modify, is_cudagraphable=False)

img = torch.randn(3, 64, 64, device="cuda")

def f(img):
    x = (img + 10) * 2
    y = movement(x)
    z = y + 1
    u = transform(z)
    v = 2*u + 1
    out = modify(v)
    return out + 1

compiled_f = torch.compile(f, mode="reduce-overhead", fullgraph=True)

eager_out = f(img)

for _ in range(3):
    compiled_out = compiled_f(img)
    assert torch.allclose(eager_out, compiled_out)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149782
Approved by: https://github.com/zou3519
2025-03-27 16:23:07 +00:00
Rachel Guo
48cff64a54 [pt2_provenance_tracing] add combo kernel nodes post_grad nodes origin info (#149598)
Summary: found it helpful when running prod model with combo_kernel feature enabled

Test Plan: CI

Differential Revision: D71513304

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149598
Approved by: https://github.com/yushangdi
2025-03-27 00:26:24 +00:00
PyTorch MergeBot
185aaaaf8e Revert "Improve subproc autotuning implementation (#149700)"
This reverts commit 8cd6a133f2.

Reverted https://github.com/pytorch/pytorch/pull/149700 on behalf of https://github.com/yangw-dev due to This is breaking servicelab_benchmark_pyper_local_runner internally ([comment](https://github.com/pytorch/pytorch/pull/149700#issuecomment-2755975959))
2025-03-26 23:17:01 +00:00
Sam Larsen
8cd6a133f2 Improve subproc autotuning implementation (#149700)
Summary: The primary change is to update the autotune-in-a-subproc implementation to avoid using multiprocessing spawn. Spawn (re)executes the toplevel script in the subproc, which can be problematic. The approach here is similar to Triton parallel compile: we Popen a subproc on a controlled entry point and communicate over pipes. That change drove a lot of refactoring in the TuningProcess class, so I took the opportunity to simplify some things, rename some methods, etc.

One other notable change is around the timeout / kill approach. After a timeout, we were previously attempting to stop the subproc in three steps (graceful shutdown, sigkill if graceful fails, sigterm if sigkill fails). I'm gonna argue think that's not useful: 1) The graceful shutdown is never going to work unless the subproc happens to have just completed its task and is ready to receive the next command. 2) If we're going to kill the subproc, let's just take the most aggressive approach and move on as quickly as possible to restarting it rather than waiting to see if previous shutdown attempts succeeded. The only downside that I can find find is maybe a little log spew?, e.g., ` ResourceWarning: subprocess 2987680 is still running`

List of changes:
* Use Popen instead of spawn for the autotuning subprocess.
* Introduced a new entry point `__autotune_main__.py`
* Renamed some TuningProcess methods. For example `shutdown` makes more sense than `terminate` because the latter implies a forced kill.
* Simplified the implementation around benchmarking timeout and how we kill the subproc after a timeout.
* Deprecated the unused timeout configs in `_inductor/config.py`
* Moved `get_ld_library_path` helper to a common utils file.
* Added more unit tests for subproc crashes / timeouts / exceptions, etc.

Test plan:
* New unit tests
* Also ran internally with all combinations of: build mode `opt` and `dev-nosan`, and `buck run` vs. executing the `.par` file directly.
* Made sure the functionality to parallelize autotuning across different GPUs is working (it wasn't clear to me this was behaving the way we wanted it to).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149700
Approved by: https://github.com/aorenste, https://github.com/jansel, https://github.com/eellison
2025-03-25 20:07:28 +00:00
Ding, Yi1
f7d1b966c2 [Inductor] Unify the data type propagation between Triton and CPP Backend (#146970)
Fixes #144246

Use `DtypePropagationOpsHandler` for CSE variables of CPP backend. In addition, add static type checking for the generated CPP code similar to the `config.test_configs.runtime_triton_dtype_assert`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146970
Approved by: https://github.com/jgong5, https://github.com/eellison, https://github.com/leslie-fang-intel
2025-03-21 17:52:51 +00:00
James Wu
7bb9c36784 Hook StaticCudaLauncher up to torch.compile (cold start) (#148890)
This hooks up the previous PR to torch.compile. Will add a config flag to hide this behind in a bit, but for now it's useful for testing purposes to have it on by default.

Inductor will automatically choose to use StaticCudaLauncher to launch triton kernels if:
- The kernel is a cuda kernel and inductor can find a cubin file associated with it
- The kernel takes less than 50 arguments
- The kernel doesn't use any special features (launch hooks, large amounts of shared memory)
- The kernel is not user defined (to be supported in a later PR)

We split CompileResult into TritonCompileResult and StaticTritonCompileResult, but have them share implementations of how they exec a python launcher. StaticTritonCompileResult's python launcher has the benefit of a simpler def_args/call_args setup, since it always filters out all constexprs before running, no matter the triton version.

Some key features of StaticTritonCompileResult:
- It is fully serializable
- It stores the minimum amount of stuff, so that later it can be cached easily
- It does not depend on any triton specific types (though it does have various triton metadata).

For now, both TritonCompileResult and StaticTritonCompileResult still `exec` custom python launchers, and use GridExpr. We can change that in the future to simplify if we'd like. For now though, this custom python codegen is good for flexibility when it comes to supporting removal of constexprs, so using it for static launching is nice to not have to pay the cost of removing constexprs at kernel runtime.

Hooking everything up to torch.compile lets me run every unit test with StaticCudaLauncher to make sure that we still pass (even if we bypass StaticCudaLauncher itself). It also lets me check for compilation/runtime performance with these changes.

Fixes #149448

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148890
Approved by: https://github.com/jansel
2025-03-20 17:32:20 +00:00
Aaron Gokaslan
a0ac63cbd9 [BE]: Apply ruff PERF403 to use dict comprehensions more often (#149257)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149257
Approved by: https://github.com/jansel
2025-03-18 00:46:07 +00:00
PyTorch MergeBot
24cfeec2c7 Revert "[BE]: Apply ruff PERF403 to use dict comprehensions more often (#149257)"
This reverts commit bfee141666.

Reverted https://github.com/pytorch/pytorch/pull/149257 on behalf of https://github.com/malfet due to Let's see if it helps restore compiler benchmark sanity, see 8bc7bd94a5/1 ([comment](https://github.com/pytorch/pytorch/pull/149257#issuecomment-2731133812))
2025-03-17 22:57:00 +00:00