Commit Graph

521 Commits

Author SHA1 Message Date
PyTorch MergeBot
eb152ab1dd Revert "Inductor logging + analysis of torch.profile (#149697)"
This reverts commit 060838c231.

Reverted https://github.com/pytorch/pytorch/pull/149697 on behalf of https://github.com/clee2000 due to broke a bunch of tests internally D76299454, probably also broke rocm inductor/test_analysis.py::TestAnalysisCUDA::test_augment_trace_against_flop_counter_maxat0_cuda_float16 [GH job link](https://github.com/pytorch/pytorch/actions/runs/15545277599/job/43766911025) [HUD commit link](060838c231) ([comment](https://github.com/pytorch/pytorch/pull/149697#issuecomment-2959747153))
2025-06-10 15:38:40 +00:00
Laith Sakka
a205e8fd73 Apply all replacements on backward graph args during inductor codegen. (#155469)
Summary: temporary mitigation for https://github.com/pytorch/pytorch/issues/155468

Test Plan:
NA

Differential Revision: D76096355

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155469
Approved by: https://github.com/bobrenjc93
2025-06-10 08:56:18 +00:00
PyTorch MergeBot
e12597090c Revert "Update auto-tuning support for _scaled_grouped_mm (#150944)"
This reverts commit 09328eb02f.

Reverted https://github.com/pytorch/pytorch/pull/150944 on behalf of https://github.com/davidberard98 due to breaks internal usage & complicates triton pin update - more details in https://github.com/pytorch/pytorch/pull/150944#issuecomment-2957246463 ([comment](https://github.com/pytorch/pytorch/pull/150944#issuecomment-2957248841))
2025-06-09 23:12:56 +00:00
Gabriel Ferns
060838c231 Inductor logging + analysis of torch.profile (#149697)
Prereqs:
 - https://github.com/pytorch/pytorch/pull/152708

Features:
1. Adds inductor's estimate of flops and bandwidth to the json trace events that perfetto uses.
2. Uses the tflops estimation from triton only when we don't have the info from the hardware datasheet, because Triton's estimates are inaccurate. I have a backlog item to fix triton flops estimation upstream. Adds a new `DeviceInfo` class and a new function `get_device_tflops`.
3. New helpers `countable_fx` and `count_flops_fx` help get the flops of an `fx.Node`.
4. Extends Triton `torch.profiler` logging to `DebugAutotuner`.
5. New script `profile_analysis.py`: `--augment_trace` adds perf estimates to any perfetto json trace, `--analyze` creates a summary table of these perf estimates, and `--diff` compares two traces side by side:
```
Device(NVIDIA H100, 0):
 Kernel Name                              | resnet Kernel Count | resnet FLOPS       | resnet bw gbps        | resnet Dur (ms)    | resnet Achieved FLOPS % | resnet Achieved Bandwidth % | newresnet Kernel Count | newresnet FLOPS    | newresnet bw gbps     | newresnet Dur (ms) | newresnet Achieved FLOPS % | newresnet Achieved Bandwidth %
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
 triton_poi_fused__native_batch_norm_legi | 24                  | 0                  | 0.11395268248131513   | 2.5919166666666666 | 0                       | 0.003401572611382541        | 24                     | 0                  | 0.11395268248131513   | 2.5919166666666666 | 0                          | 0.003401572611382541
 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 142                 | 16932673552.422373 | 0.2585007824198784    | 12.441619718309857 | 0.08683422334575583     | 0.007716441266265022        | 142                    | 16932673552.422373 | 0.2585007824198784    | 12.441619718309857 | 0.08683422334575583        | 0.007716441266265022
 triton_red_fused__native_batch_norm_legi | 39                  | 0                  | 0.13990024992108846   | 5.752589743589743  | 0                       | 0.004176126863316074        | 39                     | 0                  | 0.13990024992108846   | 5.752589743589743  | 0                          | 0.004176126863316074
 triton_poi_fused__native_batch_norm_legi | 25                  | 0                  | 0.31824055917536503   | 2.5291999999999994 | 0                       | 0.009499718184339253        | 25                     | 0                  | 0.31824055917536503   | 2.5291999999999994 | 0                          | 0.009499718184339253
 void cutlass::Kernel2<cutlass_80_tensoro | 98                  | 16211056473.596165 | 0.42972434051025826   | 7.130408163265306  | 0.08313362294151874     | 0.012827592254037562        | 98                     | 16211056473.596165 | 0.42972434051025826   | 7.130408163265306  | 0.08313362294151874        | 0.012827592254037562
 triton_red_fused__native_batch_norm_legi | 73                  | 0                  | 0.3225381327611705    | 9.987068493150682  | 0                       | 0.009628003963020014        | 73                     | 0                  | 0.3225381327611705    | 9.987068493150682  | 0                          | 0.009628003963020014
 triton_poi_fused__native_batch_norm_legi | 15                  | 0                  | 1.4491211346487216    | 4.439333333333333  | 0                       | 0.043257347302946926        | 15                     | 0                  | 1.4491211346487216    | 4.439333333333333  | 0                          | 0.043257347302946926
 void cutlass::Kernel2<cutlass_80_tensoro | 186                 | 14501701145.337954 | 0.2667131401910989    | 7.873865591397849  | 0.07436769818122027     | 0.007961586274361157        | 186                    | 14501701145.337954 | 0.2667131401910989    | 7.873865591397849  | 0.07436769818122027        | 0.007961586274361157
 triton_poi_fused__native_batch_norm_legi | 33                  | 0                  | 1.4924556538193923    | 4.3101515151515155 | 0                       | 0.044550915039384846        | 33                     | 0                  | 1.4924556538193923    | 4.3101515151515155 | 0                          | 0.044550915039384846
 triton_red_fused__native_batch_norm_legi | 29                  | 0                  | 0.25562590522631107   | 6.296275862068965  | 0                       | 0.007630624036606301        | 29                     | 0                  | 0.25562590522631107   | 6.296275862068965  | 0                          | 0.007630624036606301
 triton_poi_fused__native_batch_norm_legi | 13                  | 0                  | 0.5870562174192726    | 2.7397692307692307 | 0                       | 0.01752406619162008         | 13                     | 0                  | 0.5870562174192726    | 2.7397692307692307 | 0                          | 0.01752406619162008
 triton_poi_fused__native_batch_norm_legi | 34                  | 0                  | 0.41409928846284      | 2.853588235294117  | 0                       | 0.012361172789935523        | 34                     | 0                  | 0.41409928846284      | 2.853588235294117  | 0                          | 0.012361172789935523
 triton_per_fused__native_batch_norm_legi | 34                  | 0                  | 0.11705315007018151   | 3.460647058823529  | 0                       | 0.0034941238826919864       | 34                     | 0                  | 0.11705315007018151   | 3.460647058823529  | 0                          | 0.0034941238826919864
 triton_poi_fused__native_batch_norm_legi | 16                  | 0                  | 0.17207853197124584   | 2.3459375000000002 | 0                       | 0.005136672596156592        | 16                     | 0                  | 0.17207853197124584   | 2.3459375000000002 | 0                          | 0.005136672596156592
 triton_per_fused__native_batch_norm_legi | 30                  | 0                  | 0.2639714322022256    | 6.131199999999999  | 0                       | 0.007879744244842555        | 30                     | 0                  | 0.2639714322022256    | 6.131199999999999  | 0                          | 0.007879744244842555
 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 100                 | 11875430356.891787 | 0.19494470869421385   | 16.36534           | 0.06089964285585531     | 0.005819245035648175        | 100                    | 11875430356.891787 | 0.19494470869421385   | 16.36534           | 0.06089964285585531        | 0.005819245035648175
 triton_poi_fused__native_batch_norm_legi | 8                   | 0                  | 0.9854096626224687    | 3.2757500000000004 | 0                       | 0.029415213809625928        | 8                      | 0                  | 0.9854096626224687    | 3.2757500000000004 | 0                          | 0.029415213809625928
 void cublasLt::splitKreduce_kernel<32, 1 | 56                  | 34377923395.147064 | 0.8310300045762317    | 3.4199999999999986 | 0.17629704305203628     | 0.024806865808245714        | 56                     | 34377923395.147064 | 0.8310300045762317    | 3.4199999999999986 | 0.17629704305203628        | 0.024806865808245714
 triton_poi_fused__native_batch_norm_legi | 23                  | 0                  | 0.9944002965861103    | 3.2431304347826084 | 0                       | 0.02968359094286896         | 23                     | 0                  | 0.9944002965861103    | 3.2431304347826084 | 0                          | 0.02968359094286896
 triton_per_fused__native_batch_norm_legi | 10                  | 0                  | 0.1826801058931057    | 4.428800000000001  | 0                       | 0.00545313748934644         | 10                     | 0                  | 0.1826801058931057    | 4.428800000000001  | 0                          | 0.00545313748934644
 triton_poi_fused__native_batch_norm_legi | 10                  | 0                  | 0.3168973585366449    | 2.5471999999999997 | 0                       | 0.009459622642884923        | 10                     | 0                  | 0.3168973585366449    | 2.5471999999999997 | 0                          | 0.009459622642884923
 triton_poi_fused__native_batch_norm_legi | 34                  | 0                  | 1.1463614897015777    | 4.124323529411764  | 0                       | 0.03421974596124114         | 34                     | 0                  | 1.1463614897015777    | 4.124323529411764  | 0                          | 0.03421974596124114
 void cask_plugin_cudnn::xmma_cudnn::init | 44                  | 44045510816.64277  | 2.0661232850348643    | 3.6887499999999993 | 0.22587441444432194     | 0.06167532194133924         | 44                     | 44045510816.64277  | 2.0661232850348643    | 3.6887499999999993 | 0.22587441444432194        | 0.06167532194133924
 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 95                  | 7876855400.165316  | 0.4694941555946739    | 18.224315789473682 | 0.04039413025725802     | 0.014014750913273854        | 95                     | 7876855400.165316  | 0.4694941555946739    | 18.224315789473682 | 0.04039413025725802        | 0.014014750913273854
 triton_per_fused__native_batch_norm_legi | 41                  | 0                  | 0.06825669875995298   | 3.0384146341463416 | 0                       | 0.002037513395819492        | 41                     | 0                  | 0.06825669875995298   | 3.0384146341463416 | 0                          | 0.002037513395819492
 triton_poi_fused__native_batch_norm_legi | 23                  | 0                  | 0.08808154712430301   | 2.3275652173913044 | 0                       | 0.0026292999141582997       | 23                     | 0                  | 0.08808154712430301   | 2.3275652173913044 | 0                          | 0.0026292999141582997
 triton_per_fused__native_batch_norm_legi | 40                  | 0                  | 0.18179321034952417   | 4.556825           | 0                       | 0.005426662995508183        | 40                     | 0                  | 0.18179321034952417   | 4.556825           | 0                          | 0.005426662995508183
 triton_poi_fused__native_batch_norm_legi | 15                  | 0                  | 0.5887415155454232    | 2.783866666666667  | 0                       | 0.017574373598370836        | 15                     | 0                  | 0.5887415155454232    | 2.783866666666667  | 0                          | 0.017574373598370836
 void cutlass::Kernel2<cutlass_80_tensoro | 38                  | 14242013806.264643 | 0.256592404353939     | 7.217631578947369  | 0.0730359682372546      | 0.007659474756834           | 38                     | 14242013806.264643 | 0.256592404353939     | 7.217631578947369  | 0.0730359682372546         | 0.007659474756834
 triton_poi_fused__native_batch_norm_legi | 21                  | 0                  | 0.5842860973430516    | 2.7779047619047623 | 0                       | 0.017441376040091088        | 21                     | 0                  | 0.5842860973430516    | 2.7779047619047623 | 0                          | 0.017441376040091088
 triton_per_fused__native_batch_norm_legi | 16                  | 0                  | 0.11509365173486417   | 3.5959375000000002 | 0                       | 0.0034356313950705724       | 16                     | 0                  | 0.11509365173486417   | 3.5959375000000002 | 0                          | 0.0034356313950705724
 triton_poi_fused__native_batch_norm_legi | 14                  | 0                  | 0.1704672000243914    | 2.4044285714285714 | 0                       | 0.00508857313505646         | 14                     | 0                  | 0.1704672000243914    | 2.4044285714285714 | 0                          | 0.00508857313505646
 triton_poi_fused__native_batch_norm_legi | 58                  | 0                  | 2.307520779930795     | 8.190706896551722  | 0                       | 0.06888121731136704         | 58                     | 0                  | 2.307520779930795     | 8.190706896551722  | 0                          | 0.06888121731136704
 triton_per_fused__native_batch_norm_legi | 29                  | 0                  | 0.037243248971881276  | 3.0277586206896556 | 0                       | 0.001111738775280038        | 29                     | 0                  | 0.037243248971881276  | 3.0277586206896556 | 0                          | 0.001111738775280038
 triton_poi_fused__native_batch_norm_legi | 20                  | 0                  | 0.04741699795428918   | 2.2911500000000005 | 0                       | 0.0014154327747549007       | 20                     | 0                  | 0.04741699795428918   | 2.2911500000000005 | 0                          | 0.0014154327747549007
 triton_per_fused__native_batch_norm_legi | 25                  | 0                  | 0.13357016893727824   | 3.37536            | 0                       | 0.003987169222008305        | 25                     | 0                  | 0.13357016893727824   | 3.37536            | 0                          | 0.003987169222008305
 triton_poi_fused__native_batch_norm_legi | 13                  | 0                  | 0.3089862268300253    | 2.8111538461538457 | 0                       | 0.009223469457612694        | 13                     | 0                  | 0.3089862268300253    | 2.8111538461538457 | 0                          | 0.009223469457612694
 triton_poi_fused__native_batch_norm_legi | 17                  | 0                  | 0.3129385387909844    | 2.673              | 0                       | 0.009341448919133863        | 17                     | 0                  | 0.3129385387909844    | 2.673              | 0                          | 0.009341448919133863
 triton_per_fused__native_batch_norm_legi | 19                  | 0                  | 0.2215568162533158    | 3.8837368421052636 | 0                       | 0.0066136363060691275       | 19                     | 0                  | 0.2215568162533158    | 3.8837368421052636 | 0                          | 0.0066136363060691275
 std::enable_if<!(false), void>::type int | 23                  | 504916805.19297093 | 1.0118296096314707    | 8.113913043478261  | 0.0025893169497075447   | 0.030203868944223014        | 23                     | 504916805.19297093 | 1.0118296096314707    | 8.113913043478261  | 0.0025893169497075447      | 0.030203868944223014
 triton_poi_fused_add_copy__38            | 56                  | 0                  | 0                     | 2.132482142857143  | 0                       | 0                           | 56                     | 0                  | 0                     | 2.132482142857143  | 0                          | 0
 triton_poi_fused_convolution_0           | 18                  | 0                  | 0.43458610794936897   | 2.773333333333334  | 0                       | 0.012972719640279667        | 18                     | 0                  | 0.43458610794936897   | 2.773333333333334  | 0                          | 0.012972719640279667
 triton_poi_fused_convolution_1           | 17                  | 0                  | 0.028816312469162712  | 2.6145882352941174 | 0                       | 0.0008601884319153051       | 17                     | 0                  | 0.028816312469162712  | 2.6145882352941174 | 0                          | 0.0008601884319153051
 void convolve_common_engine_float_NHWC<f | 44                  | 8641868995.31118   | 0.024730540008465626  | 25.87327272727273  | 0.04431727689903169     | 0.0007382250748795709       | 44                     | 8641868995.31118   | 0.024730540008465626  | 25.87327272727273  | 0.04431727689903169        | 0.0007382250748795709
 triton_per_fused__native_batch_norm_legi | 12                  | 0                  | 0.6809930918986744    | 4.82675            | 0                       | 0.020328151996975356        | 12                     | 0                  | 0.6809930918986744    | 4.82675            | 0                          | 0.020328151996975356
 triton_per_fused__native_batch_norm_legi | 14                  | 0                  | 0.02883030597936608   | 2.6651428571428575 | 0                       | 0.0008606061486377935       | 14                     | 0                  | 0.02883030597936608   | 2.6651428571428575 | 0                          | 0.0008606061486377935
 triton_per_fused__native_batch_norm_legi | 16                  | 0                  | 0.0014658988233201874 | 2.098              | 0                       | 4.375817383045335e-05       | 16                     | 0                  | 0.0014658988233201874 | 2.098              | 0                          | 4.375817383045335e-05
 triton_poi_fused__native_batch_norm_legi | 13                  | 0                  | 0.9926297180284697    | 3.2367692307692306 | 0                       | 0.02963073785159611         | 13                     | 0                  | 0.9926297180284697    | 3.2367692307692306 | 0                          | 0.02963073785159611
 triton_poi_fused__native_batch_norm_legi | 9                   | 0                  | 1.3008817095666507    | 3.0863333333333336 | 0                       | 0.03883228983781048         | 9                      | 0                  | 1.3008817095666507    | 3.0863333333333336 | 0                          | 0.03883228983781048
 void at::native::(anonymous namespace):: | 98                  | 0                  | 0.09174335613709389   | 4.408520408163265  | 0                       | 0.0027386076458833994       | 98                     | 0                  | 0.09174335613709389   | 4.408520408163265  | 0                          | 0.0027386076458833994
 void at::native::vectorized_elementwise_ | 7                   | 0                  | 0                     | 1.7278571428571428 | 0                       | 0                           | 7                      | 0                  | 0                     | 1.7278571428571428 | 0                          | 0
```
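For reference, invocations of the new script might look like the following sketch. Only the three flag names come from this commit message; the positional-argument shapes and the script's location are assumptions for illustration.

```python
# Hypothetical usage sketch of profile_analysis.py. Only the flag names
# (--augment_trace, --analyze, --diff) come from the commit message; the
# positional arguments and the script path are assumed, not confirmed.
import subprocess

# Add inductor perf estimates to an existing perfetto json trace.
subprocess.run(["python", "profile_analysis.py", "--augment_trace", "resnet.json", "resnet_augmented.json"], check=True)

# Summarize the perf estimates in an augmented trace as a table.
subprocess.run(["python", "profile_analysis.py", "--analyze", "resnet_augmented.json"], check=True)

# Compare two traces side by side, producing a table like the one above.
subprocess.run(["python", "profile_analysis.py", "--diff", "resnet.json", "newresnet.json"], check=True)
```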

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149697
Approved by: https://github.com/eellison, https://github.com/shunting314
2025-06-09 21:43:21 +00:00
Aleksandar Samardžić
09328eb02f Update auto-tuning support for _scaled_grouped_mm (#150944)
1. Enable strided inputs
2. Implement "2d/2d", "3d/2d" and "3d/3d" combinations of inputs
3. Fix non-TMA load variant
4. Replace experimental_device_tensormap_create2d with _experimental_make_tensor_descriptor
5. Fix cases where the group size along the K dimension is not a multiple of the block size along K
6. Update meta registration
7. Update synthetic offsets creation

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150944
Approved by: https://github.com/ngimel
2025-06-08 10:18:13 +00:00
Aaron Gokaslan
83d22256f8 [BE][Ez]: Improve typing in torch._logging (#155345)
Add a few missing return type annotations in torch._logging, using ruff to infer the obvious ones.
`LazyStr` now properly checks the return type of the `Callable` and the args and kwargs passed to it.
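As a rough illustration of that contract (a sketch only; the real `LazyStr` lives in torch._logging and may differ in detail):

```python
# Illustrative sketch of a typed lazy log string, not the actual
# torch._logging definition.
from typing import Any, Callable

class LazyStr:
    """Defers building an expensive log string until it is actually rendered."""

    def __init__(self, func: Callable[..., str], *args: Any, **kwargs: Any) -> None:
        # The Callable must return str; args/kwargs are forwarded to it.
        self.func, self.args, self.kwargs = func, args, kwargs

    def __str__(self) -> str:
        return self.func(*self.args, **self.kwargs)
```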

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155345
Approved by: https://github.com/ezyang
2025-06-07 00:04:39 +00:00
PyTorch MergeBot
7e4c097b07 Revert "[inductor] Add typing to _inductor/ir.py (#149958)"
This reverts commit 529e0357c6.

Reverted https://github.com/pytorch/pytorch/pull/149958 on behalf of https://github.com/malfet due to Looks like it broke inductor_torchbind tests, due to more graphbreaks, see b0fbbef136/1 ([comment](https://github.com/pytorch/pytorch/pull/149958#issuecomment-2949583209))
2025-06-06 15:19:16 +00:00
Tom Ritchford
529e0357c6 [inductor] Add typing to _inductor/ir.py (#149958)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149958
Approved by: https://github.com/Skylion007
2025-06-06 14:15:01 +00:00
eellison
0827464002 Replace runtime type parameterization (#155221)
See:

```
>>> import timeit; print(f"OrderedSet[str](): {timeit.timeit('OrderedSet[str]()', setup='from torch.utils._ordered_set import OrderedSet', number=1000000):.6f}s, OrderedSet(): {timeit.timeit('OrderedSet()', setup='from torch.utils._ordered_set import OrderedSet', number=1000000):.6f}s")
```
> `OrderedSet[str]()`: 0.354622s, `OrderedSet()`: 0.095376s

Type parameterization should live in the type hint, not happen at runtime.
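Concretely, the change amounts to this pattern (using `OrderedSet` from `torch.utils._ordered_set`, as in the benchmark above):

```python
from torch.utils._ordered_set import OrderedSet

# Before: subscripting at runtime builds a typing proxy object on every
# call and then calls it, ~3.7x slower per the timings above.
names: OrderedSet[str] = OrderedSet[str]()

# After: keep the parameterization in the type hint only.
names: OrderedSet[str] = OrderedSet()
```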

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155221
Approved by: https://github.com/Skylion007, https://github.com/jansel
2025-06-05 21:43:54 +00:00
PyTorch MergeBot
5e03433443 Revert "Inductor logging + analysis of torch.profile (#149697)"
This reverts commit e5afbe3124.

Reverted https://github.com/pytorch/pytorch/pull/149697 on behalf of https://github.com/malfet due to Broke rocm, see 642687af29/1 ([comment](https://github.com/pytorch/pytorch/pull/149697#issuecomment-2942415600))
2025-06-05 01:38:13 +00:00
Gabriel Ferns
e5afbe3124 Inductor logging + analysis of torch.profile (#149697)
Prereqs:
 - https://github.com/pytorch/pytorch/pull/152708

Features:
1. Adds inductor's estimate of flops and bandwidth to the json trace events that perfetto uses.
2. Uses the tflops estimation from triton only when we don't have the info from the hardware datasheet, because Triton's estimates are inaccurate. I have a backlog item to fix triton flops estimation upstream. Adds a new `DeviceInfo` class and a new function `get_device_tflops`.
3. New helpers `countable_fx` and `count_flops_fx` help get the flops of an `fx.Node`.
4. Extends Triton `torch.profiler` logging to `DebugAutotuner`.
5. New script `profile_analysis.py`: `--augment_trace` adds perf estimates to any perfetto json trace, `--analyze` creates a summary table of these perf estimates, and `--diff` compares two traces side by side:
```
Device(NVIDIA H100, 0):
 Kernel Name                              | resnet Kernel Count | resnet FLOPS       | resnet bw gbps        | resnet Dur (ms)    | resnet Achieved FLOPS % | resnet Achieved Bandwidth % | newresnet Kernel Count | newresnet FLOPS    | newresnet bw gbps     | newresnet Dur (ms) | newresnet Achieved FLOPS % | newresnet Achieved Bandwidth %
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
 triton_poi_fused__native_batch_norm_legi | 24                  | 0                  | 0.11395268248131513   | 2.5919166666666666 | 0                       | 0.003401572611382541        | 24                     | 0                  | 0.11395268248131513   | 2.5919166666666666 | 0                          | 0.003401572611382541
 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 142                 | 16932673552.422373 | 0.2585007824198784    | 12.441619718309857 | 0.08683422334575583     | 0.007716441266265022        | 142                    | 16932673552.422373 | 0.2585007824198784    | 12.441619718309857 | 0.08683422334575583        | 0.007716441266265022
 triton_red_fused__native_batch_norm_legi | 39                  | 0                  | 0.13990024992108846   | 5.752589743589743  | 0                       | 0.004176126863316074        | 39                     | 0                  | 0.13990024992108846   | 5.752589743589743  | 0                          | 0.004176126863316074
 triton_poi_fused__native_batch_norm_legi | 25                  | 0                  | 0.31824055917536503   | 2.5291999999999994 | 0                       | 0.009499718184339253        | 25                     | 0                  | 0.31824055917536503   | 2.5291999999999994 | 0                          | 0.009499718184339253
 void cutlass::Kernel2<cutlass_80_tensoro | 98                  | 16211056473.596165 | 0.42972434051025826   | 7.130408163265306  | 0.08313362294151874     | 0.012827592254037562        | 98                     | 16211056473.596165 | 0.42972434051025826   | 7.130408163265306  | 0.08313362294151874        | 0.012827592254037562
 triton_red_fused__native_batch_norm_legi | 73                  | 0                  | 0.3225381327611705    | 9.987068493150682  | 0                       | 0.009628003963020014        | 73                     | 0                  | 0.3225381327611705    | 9.987068493150682  | 0                          | 0.009628003963020014
 triton_poi_fused__native_batch_norm_legi | 15                  | 0                  | 1.4491211346487216    | 4.439333333333333  | 0                       | 0.043257347302946926        | 15                     | 0                  | 1.4491211346487216    | 4.439333333333333  | 0                          | 0.043257347302946926
 void cutlass::Kernel2<cutlass_80_tensoro | 186                 | 14501701145.337954 | 0.2667131401910989    | 7.873865591397849  | 0.07436769818122027     | 0.007961586274361157        | 186                    | 14501701145.337954 | 0.2667131401910989    | 7.873865591397849  | 0.07436769818122027        | 0.007961586274361157
 triton_poi_fused__native_batch_norm_legi | 33                  | 0                  | 1.4924556538193923    | 4.3101515151515155 | 0                       | 0.044550915039384846        | 33                     | 0                  | 1.4924556538193923    | 4.3101515151515155 | 0                          | 0.044550915039384846
 triton_red_fused__native_batch_norm_legi | 29                  | 0                  | 0.25562590522631107   | 6.296275862068965  | 0                       | 0.007630624036606301        | 29                     | 0                  | 0.25562590522631107   | 6.296275862068965  | 0                          | 0.007630624036606301
 triton_poi_fused__native_batch_norm_legi | 13                  | 0                  | 0.5870562174192726    | 2.7397692307692307 | 0                       | 0.01752406619162008         | 13                     | 0                  | 0.5870562174192726    | 2.7397692307692307 | 0                          | 0.01752406619162008
 triton_poi_fused__native_batch_norm_legi | 34                  | 0                  | 0.41409928846284      | 2.853588235294117  | 0                       | 0.012361172789935523        | 34                     | 0                  | 0.41409928846284      | 2.853588235294117  | 0                          | 0.012361172789935523
 triton_per_fused__native_batch_norm_legi | 34                  | 0                  | 0.11705315007018151   | 3.460647058823529  | 0                       | 0.0034941238826919864       | 34                     | 0                  | 0.11705315007018151   | 3.460647058823529  | 0                          | 0.0034941238826919864
 triton_poi_fused__native_batch_norm_legi | 16                  | 0                  | 0.17207853197124584   | 2.3459375000000002 | 0                       | 0.005136672596156592        | 16                     | 0                  | 0.17207853197124584   | 2.3459375000000002 | 0                          | 0.005136672596156592
 triton_per_fused__native_batch_norm_legi | 30                  | 0                  | 0.2639714322022256    | 6.131199999999999  | 0                       | 0.007879744244842555        | 30                     | 0                  | 0.2639714322022256    | 6.131199999999999  | 0                          | 0.007879744244842555
 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 100                 | 11875430356.891787 | 0.19494470869421385   | 16.36534           | 0.06089964285585531     | 0.005819245035648175        | 100                    | 11875430356.891787 | 0.19494470869421385   | 16.36534           | 0.06089964285585531        | 0.005819245035648175
 triton_poi_fused__native_batch_norm_legi | 8                   | 0                  | 0.9854096626224687    | 3.2757500000000004 | 0                       | 0.029415213809625928        | 8                      | 0                  | 0.9854096626224687    | 3.2757500000000004 | 0                          | 0.029415213809625928
 void cublasLt::splitKreduce_kernel<32, 1 | 56                  | 34377923395.147064 | 0.8310300045762317    | 3.4199999999999986 | 0.17629704305203628     | 0.024806865808245714        | 56                     | 34377923395.147064 | 0.8310300045762317    | 3.4199999999999986 | 0.17629704305203628        | 0.024806865808245714
 triton_poi_fused__native_batch_norm_legi | 23                  | 0                  | 0.9944002965861103    | 3.2431304347826084 | 0                       | 0.02968359094286896         | 23                     | 0                  | 0.9944002965861103    | 3.2431304347826084 | 0                          | 0.02968359094286896
 triton_per_fused__native_batch_norm_legi | 10                  | 0                  | 0.1826801058931057    | 4.428800000000001  | 0                       | 0.00545313748934644         | 10                     | 0                  | 0.1826801058931057    | 4.428800000000001  | 0                          | 0.00545313748934644
 triton_poi_fused__native_batch_norm_legi | 10                  | 0                  | 0.3168973585366449    | 2.5471999999999997 | 0                       | 0.009459622642884923        | 10                     | 0                  | 0.3168973585366449    | 2.5471999999999997 | 0                          | 0.009459622642884923
 triton_poi_fused__native_batch_norm_legi | 34                  | 0                  | 1.1463614897015777    | 4.124323529411764  | 0                       | 0.03421974596124114         | 34                     | 0                  | 1.1463614897015777    | 4.124323529411764  | 0                          | 0.03421974596124114
 void cask_plugin_cudnn::xmma_cudnn::init | 44                  | 44045510816.64277  | 2.0661232850348643    | 3.6887499999999993 | 0.22587441444432194     | 0.06167532194133924         | 44                     | 44045510816.64277  | 2.0661232850348643    | 3.6887499999999993 | 0.22587441444432194        | 0.06167532194133924
 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 95                  | 7876855400.165316  | 0.4694941555946739    | 18.224315789473682 | 0.04039413025725802     | 0.014014750913273854        | 95                     | 7876855400.165316  | 0.4694941555946739    | 18.224315789473682 | 0.04039413025725802        | 0.014014750913273854
 triton_per_fused__native_batch_norm_legi | 41                  | 0                  | 0.06825669875995298   | 3.0384146341463416 | 0                       | 0.002037513395819492        | 41                     | 0                  | 0.06825669875995298   | 3.0384146341463416 | 0                          | 0.002037513395819492
 triton_poi_fused__native_batch_norm_legi | 23                  | 0                  | 0.08808154712430301   | 2.3275652173913044 | 0                       | 0.0026292999141582997       | 23                     | 0                  | 0.08808154712430301   | 2.3275652173913044 | 0                          | 0.0026292999141582997
 triton_per_fused__native_batch_norm_legi | 40                  | 0                  | 0.18179321034952417   | 4.556825           | 0                       | 0.005426662995508183        | 40                     | 0                  | 0.18179321034952417   | 4.556825           | 0                          | 0.005426662995508183
 triton_poi_fused__native_batch_norm_legi | 15                  | 0                  | 0.5887415155454232    | 2.783866666666667  | 0                       | 0.017574373598370836        | 15                     | 0                  | 0.5887415155454232    | 2.783866666666667  | 0                          | 0.017574373598370836
 void cutlass::Kernel2<cutlass_80_tensoro | 38                  | 14242013806.264643 | 0.256592404353939     | 7.217631578947369  | 0.0730359682372546      | 0.007659474756834           | 38                     | 14242013806.264643 | 0.256592404353939     | 7.217631578947369  | 0.0730359682372546         | 0.007659474756834
 triton_poi_fused__native_batch_norm_legi | 21                  | 0                  | 0.5842860973430516    | 2.7779047619047623 | 0                       | 0.017441376040091088        | 21                     | 0                  | 0.5842860973430516    | 2.7779047619047623 | 0                          | 0.017441376040091088
 triton_per_fused__native_batch_norm_legi | 16                  | 0                  | 0.11509365173486417   | 3.5959375000000002 | 0                       | 0.0034356313950705724       | 16                     | 0                  | 0.11509365173486417   | 3.5959375000000002 | 0                          | 0.0034356313950705724
 triton_poi_fused__native_batch_norm_legi | 14                  | 0                  | 0.1704672000243914    | 2.4044285714285714 | 0                       | 0.00508857313505646         | 14                     | 0                  | 0.1704672000243914    | 2.4044285714285714 | 0                          | 0.00508857313505646
 triton_poi_fused__native_batch_norm_legi | 58                  | 0                  | 2.307520779930795     | 8.190706896551722  | 0                       | 0.06888121731136704         | 58                     | 0                  | 2.307520779930795     | 8.190706896551722  | 0                          | 0.06888121731136704
 triton_per_fused__native_batch_norm_legi | 29                  | 0                  | 0.037243248971881276  | 3.0277586206896556 | 0                       | 0.001111738775280038        | 29                     | 0                  | 0.037243248971881276  | 3.0277586206896556 | 0                          | 0.001111738775280038
 triton_poi_fused__native_batch_norm_legi | 20                  | 0                  | 0.04741699795428918   | 2.2911500000000005 | 0                       | 0.0014154327747549007       | 20                     | 0                  | 0.04741699795428918   | 2.2911500000000005 | 0                          | 0.0014154327747549007
 triton_per_fused__native_batch_norm_legi | 25                  | 0                  | 0.13357016893727824   | 3.37536            | 0                       | 0.003987169222008305        | 25                     | 0                  | 0.13357016893727824   | 3.37536            | 0                          | 0.003987169222008305
 triton_poi_fused__native_batch_norm_legi | 13                  | 0                  | 0.3089862268300253    | 2.8111538461538457 | 0                       | 0.009223469457612694        | 13                     | 0                  | 0.3089862268300253    | 2.8111538461538457 | 0                          | 0.009223469457612694
 triton_poi_fused__native_batch_norm_legi | 17                  | 0                  | 0.3129385387909844    | 2.673              | 0                       | 0.009341448919133863        | 17                     | 0                  | 0.3129385387909844    | 2.673              | 0                          | 0.009341448919133863
 triton_per_fused__native_batch_norm_legi | 19                  | 0                  | 0.2215568162533158    | 3.8837368421052636 | 0                       | 0.0066136363060691275       | 19                     | 0                  | 0.2215568162533158    | 3.8837368421052636 | 0                          | 0.0066136363060691275
 std::enable_if<!(false), void>::type int | 23                  | 504916805.19297093 | 1.0118296096314707    | 8.113913043478261  | 0.0025893169497075447   | 0.030203868944223014        | 23                     | 504916805.19297093 | 1.0118296096314707    | 8.113913043478261  | 0.0025893169497075447      | 0.030203868944223014
 triton_poi_fused_add_copy__38            | 56                  | 0                  | 0                     | 2.132482142857143  | 0                       | 0                           | 56                     | 0                  | 0                     | 2.132482142857143  | 0                          | 0
 triton_poi_fused_convolution_0           | 18                  | 0                  | 0.43458610794936897   | 2.773333333333334  | 0                       | 0.012972719640279667        | 18                     | 0                  | 0.43458610794936897   | 2.773333333333334  | 0                          | 0.012972719640279667
 triton_poi_fused_convolution_1           | 17                  | 0                  | 0.028816312469162712  | 2.6145882352941174 | 0                       | 0.0008601884319153051       | 17                     | 0                  | 0.028816312469162712  | 2.6145882352941174 | 0                          | 0.0008601884319153051
 void convolve_common_engine_float_NHWC<f | 44                  | 8641868995.31118   | 0.024730540008465626  | 25.87327272727273  | 0.04431727689903169     | 0.0007382250748795709       | 44                     | 8641868995.31118   | 0.024730540008465626  | 25.87327272727273  | 0.04431727689903169        | 0.0007382250748795709
 triton_per_fused__native_batch_norm_legi | 12                  | 0                  | 0.6809930918986744    | 4.82675            | 0                       | 0.020328151996975356        | 12                     | 0                  | 0.6809930918986744    | 4.82675            | 0                          | 0.020328151996975356
 triton_per_fused__native_batch_norm_legi | 14                  | 0                  | 0.02883030597936608   | 2.6651428571428575 | 0                       | 0.0008606061486377935       | 14                     | 0                  | 0.02883030597936608   | 2.6651428571428575 | 0                          | 0.0008606061486377935
 triton_per_fused__native_batch_norm_legi | 16                  | 0                  | 0.0014658988233201874 | 2.098              | 0                       | 4.375817383045335e-05       | 16                     | 0                  | 0.0014658988233201874 | 2.098              | 0                          | 4.375817383045335e-05
 triton_poi_fused__native_batch_norm_legi | 13                  | 0                  | 0.9926297180284697    | 3.2367692307692306 | 0                       | 0.02963073785159611         | 13                     | 0                  | 0.9926297180284697    | 3.2367692307692306 | 0                          | 0.02963073785159611
 triton_poi_fused__native_batch_norm_legi | 9                   | 0                  | 1.3008817095666507    | 3.0863333333333336 | 0                       | 0.03883228983781048         | 9                      | 0                  | 1.3008817095666507    | 3.0863333333333336 | 0                          | 0.03883228983781048
 void at::native::(anonymous namespace):: | 98                  | 0                  | 0.09174335613709389   | 4.408520408163265  | 0                       | 0.0027386076458833994       | 98                     | 0                  | 0.09174335613709389   | 4.408520408163265  | 0                          | 0.0027386076458833994
 void at::native::vectorized_elementwise_ | 7                   | 0                  | 0                     | 1.7278571428571428 | 0                       | 0                           | 7                      | 0                  | 0                     | 1.7278571428571428 | 0                          | 0
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149697
Approved by: https://github.com/eellison, https://github.com/shunting314
2025-06-04 20:03:46 +00:00
Laith Sakka
853958f82c Fix: Replacements can cause runtime assertions to disappear and can cause invalid inductor code. (#153661)
Let's first explore a couple of problems related to replacements and runtime assertions.

#### example problem 1
Suppose we have a runtime assertion that u0==s0, where u0 is an input coming from mark_unbacked. A replacement u0=s0 will be added, so the function f(u0, s0) becomes f(s0, s0); this leads to the assert not being inserted during insert_deferred_runtime_asserts.
The reason is that the insert_deferred_runtime_asserts logic inserts each assertion once all of its inputs are seen, but u0 will never be seen. The same thing can happen when we defer an assertion on backed symbols, e.g. s0==s2.
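A minimal sketch of this situation (a hypothetical repro, assuming `mark_unbacked` is available under `torch._dynamo.decorators`; this is not a test from the PR):

```python
import torch
import torch._dynamo

def f(x, y):
    torch._check(x.size(0) == y.size(0))  # deferred runtime assert: u0 == s0
    return x + y

x, y = torch.randn(8), torch.randn(8)
torch._dynamo.decorators.mark_unbacked(x, 0)  # x.size(0) traces as unbacked u0
torch._dynamo.mark_dynamic(y, 0)              # y.size(0) traces as backed s0

# The replacement u0 -> s0 turns f(u0, s0) into f(s0, s0); since u0 is never
# "seen", insert_deferred_runtime_asserts can drop the u0 == s0 assert.
torch.compile(f, fullgraph=True)(x, y)
```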

#### example problem 2
Consider u0==s0, where u0 comes from a call to .item(). Imagine that later on s0 specializes to 2. In that case s0 won't be seen as an input during insert_deferred_runtime_asserts and the assertion won't be inserted in the graph. Worse, Inductor will generate code that refers to s0 in the cpp wrapper while it does not exist, causing a failure.
internal xref: https://fb.workplace.com/groups/1075192433118967/permalink/1669766396994898/

## The solution
The runtime assertion insertion loop depends on detecting that the symbols used in the runtime assertions are seen; those symbols are either graph inputs or generated in the graph by data-dependent ops like .item().

The issues above happen when the symbols are graph inputs. To force these symbols to exist in the graph and be seen by the runtime assertion insertion, we do not apply replacements on placeholder expressions during codegen or during runtime assertion insertion.

This should not add performance overhead: we already optimized the graph with replacements, and the only effect is that we no longer mistakenly drop graph inputs that are used in runtime assertions.
I added extended testing. One unrelated follow-up I noticed: we might want to rename unbacked symbols in runtime assertions when we do unbacked renaming, but that's a different issue.
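A sketch of the idea (illustrative only; not the actual Inductor code):

```python
import sympy

def substitute_for_codegen(
    expr: sympy.Expr,
    replacements: dict[sympy.Symbol, sympy.Expr],
    placeholder_symbols: set[sympy.Symbol],
) -> sympy.Expr:
    # Apply replacements everywhere except on placeholder (graph input)
    # symbols, so runtime assertions can still see those inputs.
    safe = {s: r for s, r in replacements.items() if s not in placeholder_symbols}
    return expr.xreplace(safe)
```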

Other approaches that did not work:

#### Ban replacements on unbacked symbols
1. Does not work when we defer runtime assertions on backed symbols, e.g. s0==s1. We could also ban such replacements, but then problem 2 becomes more problematic.
2. It hurts the quality of reasoning, in a bad way.

#### Apply specialization on runtime assertions before codegen
1. Can fix some issues, but may also lead to runtime assertions becoming no-ops.
2. Does not fix the failure to insert runtime assertions during insert_deferred_runtime_asserts when an input is not detected.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153661
Approved by: https://github.com/jansel
2025-05-28 09:08:05 +00:00
Benjamin Glass
768cb734ec cpp_wrapper: build non-performance-sensitive code at O1 (#148773)
Builds on #148212, applying the same improvements to `cpp_wrapper` mode.

Benchmark results:

* [A100 Benchmarks](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Wed%2C%2014%20May%202025%2015%3A10%3A05%20GMT&stopTime=Wed%2C%2021%20May%202025%2015%3A10%3A05%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(a100)&lBranch=gh/benjaminglass1/77/orig&lCommit=ca7d0a3f16e3c511534d2cd03d695be8524570d3&rBranch=main&rCommit=1075bb37d34e483763a09c7810790d5491441e13)
* [x86 Benchmarks](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Wed%2C%2014%20May%202025%2015%3A10%3A05%20GMT&stopTime=Wed%2C%2021%20May%202025%2015%3A10%3A05%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cpu%20(x86)&lBranch=gh/benjaminglass1/77/orig&lCommit=ca7d0a3f16e3c511534d2cd03d695be8524570d3&rBranch=main&rCommit=1075bb37d34e483763a09c7810790d5491441e13)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148773
Approved by: https://github.com/desertfire
2025-05-23 00:51:20 +00:00
PyTorch MergeBot
261897734a Revert "cpp_wrapper: build non-performance-sensitive code at O1 (#148773)"
This reverts commit 3c89cfd460.

Reverted https://github.com/pytorch/pytorch/pull/148773 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it seems that pr_time_benchmark is regressed after this land ([comment](https://github.com/pytorch/pytorch/pull/148773#issuecomment-2899545140))
2025-05-22 00:11:14 +00:00
Benjamin Glass
3c89cfd460 cpp_wrapper: build non-performance-sensitive code at O1 (#148773)
Builds on #148212, applying the same improvements to `cpp_wrapper` mode.

Benchmark results:

* [A100 Benchmarks](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Wed%2C%2014%20May%202025%2015%3A10%3A05%20GMT&stopTime=Wed%2C%2021%20May%202025%2015%3A10%3A05%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(a100)&lBranch=gh/benjaminglass1/77/orig&lCommit=ca7d0a3f16e3c511534d2cd03d695be8524570d3&rBranch=main&rCommit=1075bb37d34e483763a09c7810790d5491441e13)
* [x86 Benchmarks](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Wed%2C%2014%20May%202025%2015%3A10%3A05%20GMT&stopTime=Wed%2C%2021%20May%202025%2015%3A10%3A05%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cpu%20(x86)&lBranch=gh/benjaminglass1/77/orig&lCommit=ca7d0a3f16e3c511534d2cd03d695be8524570d3&rBranch=main&rCommit=1075bb37d34e483763a09c7810790d5491441e13)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148773
Approved by: https://github.com/desertfire
2025-05-21 20:23:04 +00:00
PaulZhang12
a7c01d7f13 [Inductor] Subgraph check output strides (#153755)
Make sure the output strides of the subgraph are consistent with the original gm. Without checking strides, it was possible for the subgraph to produce NaNs via a reinterpret_tensor on the subgraph's output, which itself was not contiguous.
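A sketch of the invariant being enforced (illustrative; not the actual Inductor check):

```python
import torch

def match_output_strides(expected: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
    # If the subgraph output's strides diverge from the original gm's output
    # (e.g. a non-contiguous reinterpretation), materialize the expected layout.
    if actual.stride() == expected.stride():
        return actual
    out = torch.empty_strided(actual.shape, expected.stride(),
                              dtype=actual.dtype, device=actual.device)
    out.copy_(actual)
    return out
```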

Differential Revision: [D74691119](https://our.internmc.facebook.com/intern/diff/D74691119/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153755
Approved by: https://github.com/eellison
ghstack dependencies: #153754
2025-05-20 16:07:18 +00:00
Shangdi Yu
ee326137a9 [reland] Add graph module runtime asserts to AOTI (#153182)
Summary:
Solves https://github.com/pytorch/pytorch/issues/151925

A reland of https://github.com/pytorch/pytorch/pull/152125.

Added a try-except around the justknob internally. Also added more documentation.

Currently, AOTI only generates runtime asserts for unbacked symints. We should generate asserts for all `_assert_scalar` calls in the input graph.

Also factored out the runtime assertion logic into a separate function.

We need to generate runtime asserts directly in Inductor instead of just re-using the asserts from the input graphs, because we reuse the same ShapeEnv as before. In particular, on subsequent graph passes we would immediately turn all of these assertions into no-ops: when we evaluated their expressions, the deferred runtime assert already in the ShapeEnv would tell us "of course this expression is True".

One example is below:
```
        class Model(torch.nn.Module):
            def forward(self, a, b, c):
                nz = torch.nonzero(a)
                ones = a.new_ones([nz.size(0), b.size(0)])
                torch._check(ones.size(0) >= 1)
                equals = torch.add(ones, c)
                return equals
        torch._dynamo.mark_dynamic(c, 0)
```
When we re-use the ShapeEnv in Inductor lowering, the check that a and nonzero have the same shape would be evaluated to True after we resolve unbacked bindings using the ShapeEnv.
See `test_unbacked_equals_input_size_runtime_assertion` in test_aot_inductor.

In addition to the Inductor generated runtime asserts, we also need the runtime asserts from the input graph, because some derived runtime asserts are not generated in Inductor. One example is below:
```
        class Model(torch.nn.Module):
            def forward(self, x):
                y = x.reshape(100, -1).clone()
                y = y + 1
                return y

        dynamic_shapes = {
            "x": {0: torch.export.Dim.DYNAMIC},
        }
        x.shape[0] needs to be a multiple of 100.
```
See `test_aoti_runtime_asserts_backed_symint` in test_aot_inductor.

Example:

```
    def forward(self):
        arg0_1: "f32[s35]";

        arg0_1, = fx_pytree.tree_flatten_spec([], self._in_spec)
         # File: /data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/73a672eb896e7996/scripts/shangdiy/__pt__/pt#link-tree/scripts/shangdiy/pt.py:11 in forward, code: y = x.reshape(100, -1).clone()
        sym_size_int: "Sym(s35)" = torch.ops.aten.sym_size.int(arg0_1, 0)

         #
        mod: "Sym(Mod(s35, 100))" = sym_size_int % 100;  sym_size_int = None
        eq_2: "Sym(Eq(Mod(s35, 100), 0))" = mod == 0;  mod = None
        _assert_scalar = torch.ops.aten._assert_scalar.default(eq_2, "Runtime assertion failed for expression Eq(Mod(s35, 100), 0) on node 'eq'");  eq_2 = _assert_scalar = None

         # File: /data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/73a672eb896e7996/scripts/shangdiy/__pt__/pt#link-tree/scripts/shangdiy/pt.py:11 in forward, code: y = x.reshape(100, -1).clone()
        view: "f32[100, (s35//100)]" = torch.ops.aten.reshape.default(arg0_1, [100, -1]);  arg0_1 = None
        clone: "f32[100, (s35//100)]" = torch.ops.aten.clone.default(view);  view = None

         # File: /data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/73a672eb896e7996/scripts/shangdiy/__pt__/pt#link-tree/scripts/shangdiy/pt.py:12 in forward, code: y = y + 1
        add_6: "f32[100, 1]" = torch.ops.aten.add.Tensor(clone, 1);  clone = None
        return (add_6,)
```

Generated cpp code:

```
    auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, 1);
    auto arg0_1 = std::move(inputs[0]);
    auto arg0_1_size = arg0_1.sizes();
    int64_t s35 = arg0_1_size[0];
    inputs.clear();
    auto& kernels = static_cast<AOTInductorModelKernels&>(*this->kernels_.get());
    if (!((s35 % 100L) == 0L)) { throw std::runtime_error("Expected Eq(Mod(s35, 100), 0) to be True but received " + std::to_string(s35)); }
```

Test Plan:
```
buck run fbcode//mode/dev-nosan //caffe2/test/inductor:test_aot_inductor -- -r aoti_runtime_asserts_backed_symint
buck run fbcode//mode/dev-nosan //caffe2/test/inductor:torchinductor_dynamic_shapes -- -r test_unbacked_floordiv_simplify
TORCHINDUCTOR_SCALAR_ASSERTS_FULL=1 buck run fbcode//mode/dev-nosan //caffe2/test/inductor:test_aot_inductor -- -r test_sym_i64_input_codegen_cuda
TORCHINDUCTOR_SCALAR_ASSERTS_FULL=1  buck run fbcode//mode/dev-nosan //caffe2/test/inductor:test_aot_inductor -- -r  test_unbacked_equals_input_size
```

Differential Revision: D74361799

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153182
Approved by: https://github.com/henrylhtsang
2025-05-09 22:56:19 +00:00
Mu-Chu Lee
b3524080dc [AOTInductor] Generate kernels separately for const graph and main graph (#153040)
Summary:
We should generate the kernels for the const graph and the main graph separately. The reason is that when we run autotuning we create separate kernel calls, so we must make sure that the main graph also contains the runner.

Test Plan:
python test/inductor/test_aot_inductor.py -k test_autotune_with_constant_folding

Differential Revision: [D74347765](https://our.internmc.facebook.com/intern/diff/D74347765)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153040
Approved by: https://github.com/angelayi
2025-05-08 18:45:45 +00:00
PyTorch MergeBot
05326b7e49 Revert "Add runtime asserts to AOTI (#152125)"
This reverts commit 834bc5e414.

Reverted https://github.com/pytorch/pytorch/pull/152125 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/152125#issuecomment-2863554139))
2025-05-08 15:58:18 +00:00
Shangdi Yu
834bc5e414 Add runtime asserts to AOTI (#152125)
Summary:
Solves https://github.com/pytorch/pytorch/issues/151925

Currently, AOTI only generates runtime asserts for unbacked symints. We should generate asserts for all `_assert_scalar` calls in the input graph.

Also factored out the runtime assertion logic into a separate function.

We need to generate runtime asserts directly in Inductor instead of just re-using the asserts from the input graphs, because we reuse the same ShapeEnv as before. In particular, on subsequent graph passes we would immediately turn all of these assertions into no-ops: when we evaluated their expressions, the deferred runtime assert already in the ShapeEnv would tell us "of course this expression is True".
One example is below:
```
        class Model(torch.nn.Module):
            def forward(self, a, b, c):
                nz = torch.nonzero(a)
                ones = a.new_ones([nz.size(0), b.size(0)])
                torch._check(ones.size(0) >= 1)
                equals = torch.add(ones, c)
                return equals
        torch._dynamo.mark_dynamic(c, 0)
```
When we re-use the ShapeEnv in Inductor lowering, the check that a and nonzero have the same shape would be evaluated to True after we resolve unbacked bindings using the ShapeEnv.
See test_unbacked_equals_input_size_runtime_assertion in test_aot_inductor.

In addition to the Inductor-generated runtime asserts, we also need the runtime asserts from the input graph, because some derived runtime asserts are not generated in Inductor. One example is below:
```
        class Model(torch.nn.Module):
            def forward(self, x):
                y = x.reshape(100, -1).clone()
                y = y + 1
                return y

        dynamic_shapes = {
            "x": {0: torch.export.Dim.DYNAMIC},
        }
        x.shape[0] needs to be a multiple of 100.
```
See test_aoti_runtime_asserts_backed_symint in test_aot_inductor.

Example:

```
    def forward(self):
        arg0_1: "f32[s35]";

        arg0_1, = fx_pytree.tree_flatten_spec([], self._in_spec)
         # File: /data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/73a672eb896e7996/scripts/shangdiy/__pt__/pt#link-tree/scripts/shangdiy/pt.py:11 in forward, code: y = x.reshape(100, -1).clone()
        sym_size_int: "Sym(s35)" = torch.ops.aten.sym_size.int(arg0_1, 0)

         #
        mod: "Sym(Mod(s35, 100))" = sym_size_int % 100;  sym_size_int = None
        eq_2: "Sym(Eq(Mod(s35, 100), 0))" = mod == 0;  mod = None
        _assert_scalar = torch.ops.aten._assert_scalar.default(eq_2, "Runtime assertion failed for expression Eq(Mod(s35, 100), 0) on node 'eq'");  eq_2 = _assert_scalar = None

         # File: /data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/73a672eb896e7996/scripts/shangdiy/__pt__/pt#link-tree/scripts/shangdiy/pt.py:11 in forward, code: y = x.reshape(100, -1).clone()
        view: "f32[100, (s35//100)]" = torch.ops.aten.reshape.default(arg0_1, [100, -1]);  arg0_1 = None
        clone: "f32[100, (s35//100)]" = torch.ops.aten.clone.default(view);  view = None

         # File: /data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/73a672eb896e7996/scripts/shangdiy/__pt__/pt#link-tree/scripts/shangdiy/pt.py:12 in forward, code: y = y + 1
        add_6: "f32[100, 1]" = torch.ops.aten.add.Tensor(clone, 1);  clone = None
        return (add_6,)
```

Generated cpp code:

```
    auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, 1);
    auto arg0_1 = std::move(inputs[0]);
    auto arg0_1_size = arg0_1.sizes();
    int64_t s35 = arg0_1_size[0];
    inputs.clear();
    auto& kernels = static_cast<AOTInductorModelKernels&>(*this->kernels_.get());
    if (!((s35 % 100L) == 0L)) { throw std::runtime_error("Expected Eq(Mod(s35, 100), 0) to be True but received " + std::to_string(s35)); }
```

Test Plan:
```
buck run fbcode//mode/dev-nosan //caffe2/test/inductor:test_aot_inductor -- -r aoti_runtime_asserts_backed_symint
```

Differential Revision: D73596786

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152125
Approved by: https://github.com/henrylhtsang, https://github.com/jingsh
2025-05-08 00:27:24 +00:00
Blaine Burton Rister
bc11afd41f [Inductor] FX backend via Wrapper IR (#146942)
# Sub-PRs

These PRs contain refactors from the main one. They should be reviewed and merged first.

- https://github.com/pytorch/pytorch/pull/150458
- https://github.com/pytorch/pytorch/pull/152391
- https://github.com/pytorch/pytorch/pull/152587

# Feature

The goals of this PR are twofold.

## Goal 1: Introduce Wrapper IR as an intermediate step in wrapper codegen.

In addition to Triton/C++/Halide kernels, Inductor also generates "wrapper" code which allocates memory and calls the kernels. Originally, this wrapper code was fairly standard Python which resembled a user-written PyTorch program. Over time, various wrapper code generators have been added to accommodate things like AOTInductor, which prefers C++ code for static compilation. This complexity has bled into other parts of the codebase, as we now need if/else statements to choose between Python and C++ macros. (See an example [here](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/ir.py#L5515-L5522).) Since most of these code generation steps are conceptually identical across target languages, it seems reasonable to refactor them into some kind of intermediate representation which can be shared between the various backends. This might also make it easier to develop out-of-tree backends which cannot put their own macros in core Inductor components.

This PR takes some initial steps to formalize Inductor's wrapper codegen by generalizing the existing Memory Planning IR into a fully fledged Wrapper IR. This is pretty much identical to the existing Memory Planning IR, but it supports a richer set of ops for things like kernel definitions and calls. This refactor could help encapsulate wrapper codegen. Ideally, we don't need to worry about direct Python/C++ codegen in the main compiler files such as `ir.py`, and can instead defer these to classes like `PythonWrapperCodegen` and `CppWrapperCpu`, which operate on the Wrapper IR.
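
To make this concrete, here is a minimal sketch of what a Wrapper IR line could look like; the class and method names below (`WrapperLine`, `KernelCallLine`, `codegen_python`, `codegen_cpp`) are illustrative stand-ins, not the PR's actual definitions:

```python
# Illustrative only: one logical wrapper step, rendered per backend.
from dataclasses import dataclass

@dataclass
class WrapperLine:
    """Base class for a single wrapper-codegen step (alloc, call, free, ...)."""

@dataclass
class KernelCallLine(WrapperLine):
    kernel_name: str
    call_args: tuple

    def codegen_python(self) -> str:
        # What a Python wrapper backend might emit for this step.
        args = ", ".join(map(str, self.call_args))
        return f"{self.kernel_name}.run({args})"

    def codegen_cpp(self) -> str:
        # What a C++ (AOTInductor-style) backend might emit for the same step.
        args = ", ".join(map(str, self.call_args))
        return f"launch_{self.kernel_name}({args});"
```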

## Goal 2: Convert Wrapper IR into FX IR.

One of the main benefits of Wrapper IR is to enable more diverse Inductor backends. This PR introduces a converter from Wrapper IR into [FX IR](https://pytorch.org/docs/stable/fx.html), which is the intermediate representation most commonly used in PyTorch graph compilers. The purpose of this is to enable out-of-tree backends to consume Inductor's output in FX IR, which would hopefully make Inductor easier to leverage in novel compilers, hardware accelerators, etc.

It's not trivial to generate Python or C++ code which Inductor can compile and run, and doing so may require changes to other core Inductor files, for the reasons outlined in the previous section. The goal of supporting FX output is to enable something like `torch.compile`'s [custom backend](https://pytorch.org/docs/stable/torch.compiler_custom_backends.html) system, in which an out-of-tree backend can receive an optimized FX graph from Inductor, and compile and run it however it likes.

The typical users of this feature would likely not be part of PyTorch, and may or may not support running a kernel in eager mode. However, they can understand what `torch.empty_strided` means, compile and run Triton kernels, etc. So we just need to present them with an FX graph saying what code Inductor wants to run, which should be easier to analyze and transform in a third party system than Python or C++ source.

Since FX IR is fairly stable, this mechanism should hopefully isolate third-party backends, hardware accelerators, etc. from the implementation details of Inductor, and vice versa.
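
As a rough illustration of the conversion, lowering a hypothetical "allocate buffer" wrapper line to FX needs only the standard `torch.fx.Graph` API; the converter scaffolding below is invented for the sketch:

```python
# Sketch: turn a wrapper-level allocation step into an FX call_function node,
# mirroring the torch.empty_strided calls in the example graphs below.
import torch
import torch.fx as fx

def lower_alloc(graph: fx.Graph, size, stride, dtype, device) -> fx.Node:
    return graph.call_function(
        torch.empty_strided,
        args=(tuple(size), tuple(stride)),
        kwargs={"dtype": dtype, "device": device},
    )
```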

# Current status

Things that seem to work:

- Converted a lot of the most common Python codegen lines to Wrapper IR lines.
     - Handled the following cases, in addition to what was already in the Memory Planning IR:
         - Comments
         - Triton kernels
         - Extern/fallback kernels
         - Freeing tensors (`del buf0`)
         - MultiOutput
         - Graph outputs
         - ReinterpretView / StorageBox, for both call args and outputs.
     - FX conversion asserts that the program only contains Wrapper IR lines, and not strings of Python/C++ code.
- Prototype FX converter which can handle some of the most common use cases.
   - Defining Triton kernels, and putting them in a side table using TorchDynamo's existing [utilities](https://dev-discuss.pytorch.org/t/higher-order-operators-2023-10/1565).
   - Calling wrapped Triton kernels.
   - Calling extern kernels and certain types of fallback kernels.
       - Support both `extern_kernels.*` and `aten.*`.
       - Support multi-output kernels like `torch.topk`.
   - Graphs with multiple inputs/outputs.
   - Training, i.e., calling `Tensor.backward()` in a compiled function.
   - Graph breaks (training).
- Run the `torch.fx.GraphModule` on GPU using the standard `__call__` method. This makes it easy to test the correctness of FX codegen.

Things that don't work:
- Both Wrapper IR and Wrapper -> FX coverage are currently best effort. There are still features which aren't captured as Wrapper IR lines, and fall back to plain strings. This representation is functionally correct but probably not rich enough to achieve the goals outlined in the previous sections.
         - Fallback kernels seem like the most difficult thing to fully cover, since they each define their own Python/C++ macros that would need to be converted to FX.
         - Size/alignment asserts are currently disabled via the config file. It's possible to generate FX IR for these, but it seems reasonable to defer these sanity checks to a later PR.
         - CommBuffers and distributed communication are not yet supported. An earlier version of this PR attempted to implement this by calling `empty_strided_p2p`. However, building and testing distributed support seems non-trivial, so it's probably better to defer this.

# Out-of-tree compilers

With this PR, out of tree backends will be able to do further compilation on the FX graphs by subclassing `WrapperFxCodegen` and overriding the `compile_graph` function. This follows the same API as torch.compile's [custom backends](https://pytorch.org/docs/stable/torch.compiler_custom_backends.html), where the user simply returns a callable running the graph. The callable need not be a method of `GraphModule` or any other PyTorch class. See an example below.

```
from torch._inductor.codegen.wrapper_fxir import WrapperFxCodegen

class MyCustomBackend(WrapperFxCodegen):
    def compile_graph(self, gm):
        # Add 1 to each of the graph's outputs
        def compiled_fn(*args):
            return [x + 1 for x in gm(*args)]
        return compiled_fn
```

# Example FX graphs

This section contains some example FX graphs generated by Inductor. The correctness of these graphs was verified against eager mode by calling the corresponding `GraphModule`.

Here's an FX graph calling a basic Triton kernel. Notice how outputs are allocated with `torch.empty_strided`, and the Triton kernel is called by reference to Dynamo's triton side table.
```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((8,), (1,)), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(8,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg1_1, in_ptr1: %arg0_1, out_ptr0: %buf0, xnumel: 8, XBLOCK: 8}})
    return (buf0,)
```

Here's a more complicated graph that calls a `torch.addmm` extern kernel.

```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=2] = placeholder[target=arg1_1]
    %buf0 : [num_users=3] = call_function[target=torch.empty_strided](args = ((), ()), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(1,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg1_1, out_ptr0: %buf0, xnumel: 1, r0_numel: 129, XBLOCK: 1}})
    %buf2 : [num_users=2] = call_function[target=torch.empty_strided](args = ((129, 1), (1, 1)), kwargs = {dtype: torch.float32, device: cuda:0})
    %addmm : [num_users=0] = call_function[target=torch.addmm](args = (%buf0, %arg0_1, %arg1_1), kwargs = {alpha: 1, beta: 1, out: %buf2})
    %delete : [num_users=0] = call_function[target=torch._inductor.codegen.wrapper_fxir.delete](args = (%buf0,), kwargs = {})
    return (buf2,)
```

Here's a graph which indexes into a tuple using `operator.getitem`. This is necessary to use the output of the `torch.topk` operation.

```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %buf0 : [num_users=3] = call_function[target=torch.ops.aten.topk.default](args = (%arg0_1, 2), kwargs = {})
    %buf1 : [num_users=2] = call_function[target=operator.getitem](args = (%buf0, 0), kwargs = {})
    %buf2 : [num_users=2] = call_function[target=operator.getitem](args = (%buf0, 1), kwargs = {})
    %delete : [num_users=0] = call_function[target=torch._inductor.codegen.wrapper_fxir.delete](args = (%buf0,), kwargs = {})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(2,)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf1, xnumel: 2, XBLOCK: 2}})
    %triton_kernel_wrapper_mutation_1 : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 1, constant_args_idx: 1, grid: [(2,)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf2, xnumel: 2, XBLOCK: 2}})
    return (buf1, buf2)
```

Here's a graph that reinterprets an output tensor using `torch.as_strided`. This is one way to handle Inductor's `ReinterpretView` op.

```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((2, 4), (4, 1)), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(8,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg0_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: 8, XBLOCK: 8}})
    %buf0_view_buf0_0 : [num_users=1] = call_function[target=torch.as_strided](args = (%buf0, (8,), (1,), 0), kwargs = {})
    return (buf0_view_buf0_0,)
```

Here's a graph with dynamic shapes. This one is a little bit funky. Inductor provides a graph input for each shape symbol, which we map to a placeholder, in this example `s6`. Then, shape expressions in the generated code can refer to the symbol `s6`. The size hint for `s6` is stored in `node.meta["val"]` where `node` is the placeholder defining it. This works out in the generated python code because the placeholder defines a Python variable with the name `s6`.
```
graph():
    %s6 : [num_users=0] = placeholder[target=s6]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((s6,), (1,)), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [[-(((-s6)//8)), 1, 1]], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg2_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: s6, XBLOCK: 8}})
    return buf0
```

Here's another graph, this time with dynamic shapes and strides. The grid expression is more complex since the numel is a product of dimensions.
```
graph():
    %s10 : [num_users=0] = placeholder[target=s10]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ([s10, s10], [s10, 1]), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [[-(((s10**2)//(-64))), 1, 1]], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg2_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: s10**2, XBLOCK: 64}})
    return buf0
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146942
Approved by: https://github.com/jansel
2025-05-06 10:06:39 +00:00
PyTorch MergeBot
99dac7005f Revert "[Inductor] FX backend via Wrapper IR (#146942)"
This reverts commit a7691140a0.

Reverted https://github.com/pytorch/pytorch/pull/146942 on behalf of https://github.com/malfet due to Looks like it indeed breaks lint, see a7691140a0/1 ([comment](https://github.com/pytorch/pytorch/pull/146942#issuecomment-2852192778))
2025-05-05 20:01:29 +00:00
Blaine Burton Rister
a7691140a0 [Inductor] FX backend via Wrapper IR (#146942)
# Sub-PRs

These PRs contain refactors from the main one. They should be reviewed and merged first.

- https://github.com/pytorch/pytorch/pull/150458
- https://github.com/pytorch/pytorch/pull/152391
- https://github.com/pytorch/pytorch/pull/152587

# Feature

The goals of this PR are twofold.

## Goal 1: Introduce Wrapper IR as an intermediate step in wrapper codegen.

In addition to Triton/C++/Halide kernels, Inductor also generates "wrapper" code which allocates memory and calls the kernels. Originally, this wrapper code was fairly standard Python which resembled a user-written PyTorch program. Over time, various wrapper code generators have been added to accommodate things like AOTInductor, which prefers C++ code for static compilation. This complexity has bled into other parts of the codebase, as we now need if/else statements to choose between Python and C++ macros. (See an example [here](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/ir.py#L5515-L5522).) Since most of these code generation steps are conceptually identical across target languages, it seems reasonable to refactor them into some kind of intermediate representation which can be shared between the various backends. This might also make it easier to develop out-of-tree backends which cannot put their own macros in core Inductor components.

This PR takes some initial steps to formalize Inductor's wrapper codegen by generalizing the existing Memory Planning IR into a fully fledged Wrapper IR. This is pretty much identical to the existing Memory Planning IR, but it supports a richer set of ops for things like kernel definitions and calls. This refactor could help encapsulate wrapper codegen. Ideally, we don't need to worry about direct Python/C++ codegen in the main compiler files such as `ir.py`, and can instead defer these to classes like `PythonWrapperCodegen` and `CppWrapperCpu`, which operate on the Wrapper IR.

## Goal 2: Convert Wrapper IR into FX IR.

One of the main benefits of Wrapper IR is to enable more diverse Inductor backends. This PR introduces a converter from Wrapper IR into [FX IR](https://pytorch.org/docs/stable/fx.html), which is the intermediate representation most commonly used in PyTorch graph compilers. The purpose of this is to enable out-of-tree backends to consume Inductor's output in FX IR, which would hopefully make Inductor easier to leverage in novel compilers, hardware accelerators, etc.

It's not trivial to generate Python or C++ code which Inductor can compile and run, and doing so may require changes to other core Inductor files, for the reasons outlined in the previous section. The goal of supporting FX output is to enable something like `torch.compile`'s [custom backend](https://pytorch.org/docs/stable/torch.compiler_custom_backends.html) system, in which an out-of-tree backend can receive an optimized FX graph from Inductor, and compile and run it however it likes.

The typical users of this feature would likely not be part of PyTorch, and may or may not support running a kernel in eager mode. However, they can understand what `torch.empty_strided` means, compile and run Triton kernels, etc. So we just need to present them with an FX graph saying what code Inductor wants to run, which should be easier to analyze and transform in a third party system than Python or C++ source.

Since FX IR is fairly stable, this mechanism should hopefully isolate third-party backends, hardware accelerators, etc. from the implementation details of Inductor, and vice versa.

# Current status

Things that seem to work:

- Converted a lot of the most common Python codegen lines to Wrapper IR lines.
     - Handled the following cases, in addition to what was already in the Memory Planning IR:
         - Comments
         - Triton kernels
         - Extern/fallback kernels
         - Freeing tensors (`del buf0`)
         - MultiOutput
         - Graph outputs
         - ReinterpretView / StorageBox, for both call args and outputs.
     - FX conversion asserts that the program only contains Wrapper IR lines, and not strings of Python/C++ code.
- Prototype FX converter which can handle some of the most common use cases.
   - Defining Triton kernels, and putting them in a side table using TorchDynamo's existing [utilities](https://dev-discuss.pytorch.org/t/higher-order-operators-2023-10/1565).
   - Calling wrapped Triton kernels.
   - Calling extern kernels and certain types of fallback kernels.
       - Support both `extern_kernels.*` and `aten.*`.
       - Support multi-output kernels like `torch.topk`.
   - Graphs with multiple inputs/outputs.
   - Training, i.e., calling `Tensor.backward()` in a compiled function.
   - Graph breaks (training).
- Run the `torch.fx.GraphModule` on GPU using the standard `__call__` method. This makes it easy to test the correctness of FX codegen.

Things that don't work:
- Both Wrapper IR and Wrapper -> FX coverage are currently best effort. There are still features which aren't captured as Wrapper IR lines, and fall back to plain strings. This representation is functionally correct but probably not rich enough to achieve the goals outlined in the previous sections.
         - Fallback kernels seem like the most difficult thing to fully cover, since they each define their own Python/C++ macros that would need to be converted to FX.
         - Size/alignment asserts are currently disabled via the config file. It's possible to generate FX IR for these, but it seems reasonable to defer these sanity checks to a later PR.
         - CommBuffers and distributed communication are not yet supported. An earlier version of this PR attempted to implement this by calling `empty_strided_p2p`. However, building and testing distributed support seems non-trivial, so it's probably better to defer this.

# Out-of-tree compilers

With this PR, out of tree backends will be able to do further compilation on the FX graphs by subclassing `WrapperFxCodegen` and overriding the `compile_graph` function. This follows the same API as torch.compile's [custom backends](https://pytorch.org/docs/stable/torch.compiler_custom_backends.html), where the user simply returns a callable running the graph. The callable need not be a method of `GraphModule` or any other PyTorch class. See an example below.

```
from torch._inductor.codegen.wrapper_fxir import WrapperFxCodegen

class MyCustomBackend(WrapperFxCodegen):
    def compile_graph(self, gm):
        # Add 1 to each of the graph's outputs
        def compiled_fn(*args):
            return [x + 1 for x in gm(*args)]
        return compiled_fn
```

# Example FX graphs

This section contains some example FX graphs generated by Inductor. The correctness of these graphs was verified against eager mode by calling the corresponding `GraphModule`.

Here's an FX graph calling a basic Triton kernel. Notice how outputs are allocated with `torch.empty_strided`, and the Triton kernel is called by reference to Dynamo's triton side table.
```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((8,), (1,)), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(8,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg1_1, in_ptr1: %arg0_1, out_ptr0: %buf0, xnumel: 8, XBLOCK: 8}})
    return (buf0,)
```

Here's a more complicated graph that calls a `torch.addmm` extern kernel.

```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=2] = placeholder[target=arg1_1]
    %buf0 : [num_users=3] = call_function[target=torch.empty_strided](args = ((), ()), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(1,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg1_1, out_ptr0: %buf0, xnumel: 1, r0_numel: 129, XBLOCK: 1}})
    %buf2 : [num_users=2] = call_function[target=torch.empty_strided](args = ((129, 1), (1, 1)), kwargs = {dtype: torch.float32, device: cuda:0})
    %addmm : [num_users=0] = call_function[target=torch.addmm](args = (%buf0, %arg0_1, %arg1_1), kwargs = {alpha: 1, beta: 1, out: %buf2})
    %delete : [num_users=0] = call_function[target=torch._inductor.codegen.wrapper_fxir.delete](args = (%buf0,), kwargs = {})
    return (buf2,)
```

Here's a graph which indexes into a tuple using `operator.getitem`. This is necessary to use the output of the `torch.topk` operation.

```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %buf0 : [num_users=3] = call_function[target=torch.ops.aten.topk.default](args = (%arg0_1, 2), kwargs = {})
    %buf1 : [num_users=2] = call_function[target=operator.getitem](args = (%buf0, 0), kwargs = {})
    %buf2 : [num_users=2] = call_function[target=operator.getitem](args = (%buf0, 1), kwargs = {})
    %delete : [num_users=0] = call_function[target=torch._inductor.codegen.wrapper_fxir.delete](args = (%buf0,), kwargs = {})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(2,)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf1, xnumel: 2, XBLOCK: 2}})
    %triton_kernel_wrapper_mutation_1 : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 1, constant_args_idx: 1, grid: [(2,)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf2, xnumel: 2, XBLOCK: 2}})
    return (buf1, buf2)
```

Here's a graph that reinterprets an output tensor using `torch.as_strided`. This is one way to handle Inductor's `ReinterpretView` op.

```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((2, 4), (4, 1)), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(8,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg0_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: 8, XBLOCK: 8}})
    %buf0_view_buf0_0 : [num_users=1] = call_function[target=torch.as_strided](args = (%buf0, (8,), (1,), 0), kwargs = {})
    return (buf0_view_buf0_0,)
```

Here's a graph with dynamic shapes. This one is a little bit funky. Inductor provides a graph input for each shape symbol, which we map to a placeholder, in this example `s6`. Then, shape expressions in the generated code can refer to the symbol `s6`. The size hint for `s6` is stored in `node.meta["val"]` where `node` is the placeholder defining it. This works out in the generated python code because the placeholder defines a Python variable with the name `s6`.
```
graph():
    %s6 : [num_users=0] = placeholder[target=s6]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((s6,), (1,)), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [[-(((-s6)//8)), 1, 1]], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg2_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: s6, XBLOCK: 8}})
    return buf0
```

Here's another graph, this time with dynamic shapes and strides. The grid expression is more complex since the numel is a product of dimensions.
```
graph():
    %s10 : [num_users=0] = placeholder[target=s10]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ([s10, s10], [s10, 1]), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [[-(((s10**2)//(-64))), 1, 1]], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg2_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: s10**2, XBLOCK: 64}})
    return buf0
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146942
Approved by: https://github.com/jansel
2025-05-05 19:34:49 +00:00
rzou
2b37a726e0 Refactor layout constraint selection logic (#148104)
This PR:

- cleans up some existing comments that don't make sense anymore
- hooks the "custom_op_default_layout_constraint" config back up (it seems to
have broken); see the sketch after this list
- cleans up the "lazy registration path" which seems to never get hit
anymore
- adds dislike_padding to nodes that require exact strides
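
A hedged illustration of the two knobs involved: the config default and the per-op tag both exist in recent PyTorch, but treat the exact spellings as assumptions, and the library/op names as made up:

```python
import torch
import torch._inductor.config as inductor_config
from torch.library import Library

# Default layout constraint for custom ops that carry no layout tag:
inductor_config.custom_op_default_layout_constraint = "needs_fixed_stride_order"

# Or constrain a single op at registration time via an operator tag
# ("mylib" and "my_op" are hypothetical names):
lib = Library("mylib", "FRAGMENT")
lib.define(
    "my_op(Tensor x) -> Tensor",
    tags=(torch.Tag.needs_fixed_stride_order,),
)
```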

Test Plan:
- tests + CI

disable padding

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148104
Approved by: https://github.com/shunting314, https://github.com/eellison
2025-05-03 00:02:24 +00:00
Laith Sakka
38a9a8b7f7 Fix: Consider input defined unbacked during inductor codegen for runtime asserts (#152231)
When we use mark_unbacked, the graph will have an unbacked input SymInt. Right now,
deferred runtime assertions that use those symbols are never generated.

This PR changes that: in the forward graph, we now consider those symbols and generate the
corresponding runtime assertions for them. We still ignore them for backward, which is not ideal.

The way we generate runtime assertions is by emitting them once all the unbacked symbols
they use have been seen.

We previously skipped placeholders, because for backward we have a wacky approach where we
ignore input-defined unbacked symbols, assume that assertions using them were already emitted
in the forward, and try to emit all other runtime assertions again. See [Note [Backwards runtime asserts]].

Doing that, we end up only emitting the runtime assertions that depend on things defined solely in the backward, but we could miss checks that span inputs defined in both backward and forward (i.e., one symbol defined in the forward and passed as input to the backward, and another that is defined in the backward). This is not ideal; a better approach could be something like https://github.com/pytorch/pytorch/pull/151919, but it requires more work.
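
A hedged repro of the forward-graph case this PR fixes; `torch._dynamo.mark_unbacked` and `torch._check` are the real APIs, while the function itself is made up:

```python
import torch

def f(x):
    # After mark_unbacked, x.size(0) is an input-defined unbacked SymInt,
    # so this check should now surface as a generated runtime assert.
    torch._check(x.size(0) >= 2)
    return x * 2

x = torch.randn(8)
torch._dynamo.mark_unbacked(x, 0)
out = torch.compile(f, fullgraph=True)(x)
```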

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152231
Approved by: https://github.com/aorenste
2025-05-02 07:01:48 +00:00
Mu-Chu Lee
49cbe0ffe9 [AOTInductor] Propagate ConstantType for main graph. (#152272)
Summary:
We need to make sure all named_parameters and named_buffers are
propagated when we use runtime constant folding.

Test Plan:
python test/inductor/test_aot_inductor.py -k test_constant_type_propagation

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152272
Approved by: https://github.com/22quinn

Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
2025-04-29 12:42:17 +00:00
Laith Sakka
55595e0c85 Fix Issues in deferring runtime assertions. (#151170)
This PR fixes two bugs:
1) Set self.bound_unbacked_symbols before emitting runtime asserts, so that runtime asserts depending on the current node are included.

2) In the pass that removes unused graph inputs, we should not remove symbols that are used by runtime assertions.

Address some of the issues in https://github.com/pytorch/pytorch/issues/151127
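
A hedged repro sketch of bug (2): the input `y` exists only to feed a runtime assertion on its size, so the unused-input pruning pass must keep its symbol alive (the function is made up for illustration):

```python
import torch

def f(x, y):
    # y's only use is a runtime assertion on its (unbacked) size; pruning
    # y's "unused" size symbol would orphan the assertion.
    torch._check(y.size(0) % 2 == 0)
    return x + 1

x, y = torch.randn(3), torch.randn(4)
torch._dynamo.mark_unbacked(y, 0)
torch.compile(f, fullgraph=True)(x, y)
```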

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151170
Approved by: https://github.com/bobrenjc93, https://github.com/eellison
2025-04-16 08:10:17 +00:00
Bert Maher
2d187bf7e6 Support tuning of _scaled_grouped_mm (#150421)
This includes the default aten implementation, as well as a Triton
implementation imported from FBGEMM
(https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150421
Approved by: https://github.com/ngimel
2025-04-11 23:03:49 +00:00
Yiming Zhou
dbcd0b571d Back out "[AOTI] Always use oss schema for ExternKernelNodes serialization" (#151026)
Summary: Revert due to FC (forward compatibility) breakage

Test Plan: CI

Differential Revision: D72802075

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151026
Approved by: https://github.com/hl475
2025-04-10 22:36:35 +00:00
PyTorch MergeBot
6a65f2c4fe Revert "Support tuning of _scaled_grouped_mm (#150421)"
This reverts commit 8efcf21fff.

Reverted https://github.com/pytorch/pytorch/pull/150421 on behalf of https://github.com/malfet due to Looks like it broke lint, see a0ab243c3a/1 ([comment](https://github.com/pytorch/pytorch/pull/150421#issuecomment-2795218547))
2025-04-10 21:36:41 +00:00
Bert Maher
8efcf21fff Support tuning of _scaled_grouped_mm (#150421)
This includes the default aten implementation, as well as a Triton
implementation imported from FBGEMM
(https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150421
Approved by: https://github.com/ngimel
2025-04-10 20:34:16 +00:00
PyTorch MergeBot
01568cb17a Revert "Refactor layout constraint selection logic (#148104)"
This reverts commit 2e7c9d33e7.

Reverted https://github.com/pytorch/pytorch/pull/148104 on behalf of https://github.com/atalman due to [GH job link](https://github.com/pytorch/pytorch/actions/runs/14357056427/job/40251630946) [HUD commit link](2e7c9d33e7) ([comment](https://github.com/pytorch/pytorch/pull/148104#issuecomment-2790369493))
2025-04-09 16:49:48 +00:00
Xia, Weiwen
246f3b6530 [Quant][PT2E][X86] enable qconv1d-relu fusion (#150751)
**Summary**
As the title.
- The `conv1d - relu` pattern will be annotated by the `X86InductorQuantizer`.
- The pattern will be fused as `qconv_pointwise` during lowering.
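
For context, a hedged sketch of the PT2E flow in which the pattern gets annotated and then fused during Inductor lowering; the entry points follow the current PT2E tutorials, but treat the exact spellings as assumptions:

```python
import torch
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
    X86InductorQuantizer,
    get_default_x86_inductor_quantization_config,
)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv1d(3, 8, kernel_size=3)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))

m = M().eval()
example = (torch.randn(1, 3, 16),)
gm = torch.export.export_for_training(m, example).module()

quantizer = X86InductorQuantizer().set_global(
    get_default_x86_inductor_quantization_config()
)
gm = prepare_pt2e(gm, quantizer)   # annotates the conv1d - relu pattern
gm(*example)                       # calibration
gm = convert_pt2e(gm)
compiled = torch.compile(gm)       # lowering fuses into qconv_pointwise
```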

**Test plan**
```
python test/inductor/test_mkldnn_pattern_matcher.py -k test_qconv1d_relu_cpu
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150751
Approved by: https://github.com/jerryzh168, https://github.com/leslie-fang-intel
2025-04-09 14:42:02 +00:00
rzou
2e7c9d33e7 Refactor layout constraint selection logic (#148104)
This PR:

- cleans up some existing comments that don't make sense anymore
- hooks the "custom_op_default_layout_constraint" config back up (it seems to
have broken)
- cleans up the "lazy registration path" which seems to never get hit
anymore
- adds dislike_padding to nodes that require exact strides

Test Plan:
- tests + CI

disable padding

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148104
Approved by: https://github.com/shunting314, https://github.com/eellison
ghstack dependencies: #150495
2025-04-09 02:09:18 +00:00
Yiming Zhou
89505f4498 [AOTI] Always use oss schema for ExternKernelNodes serialization (#150197)
Summary: Added a field `protocol` to `ExternKernelNodes`; all the lowering passes will always use the OSS schema to serialize external kernel nodes from now on.

Test Plan: CI

Differential Revision: D72020444

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150197
Approved by: https://github.com/zhxchen17
2025-04-08 22:35:28 +00:00
Shunting Zhang
901b02cf16 [Inductor] fix alignement assumption for fallback (#150777)
Inductor right now only works properly for fallback kernels producing aligned output.
When Inductor creates the layout for a fallback kernel's output, it does not add the tensor offset to the layout [link](2a1e2b88ed/torch/_inductor/ir.py (L6935-L6941)). Thus unaligned output will be treated as aligned. Adding the offset to the layout directly does not work, since that changes the index expression in the generated kernel and we may apply the offset twice. Triton already accounts for the offset when passing in the data_ptr.

To solve this issue, we track the unaligned buffer names instead.
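
A hedged illustration of the hazard: a tensor whose `storage_offset` makes its `data_ptr` unaligned, which generated kernels must not assume to be aligned:

```python
import torch

x = torch.randn(17)
y = x[1:]  # storage_offset = 1 element = 4 bytes for float32
# y.data_ptr() is offset 4 bytes from the (aligned) storage start, so code
# that assumes 16-byte alignment for y would be wrong.
print(y.data_ptr() - x.data_ptr())  # 4
```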

This potentially can fix the internal issues we are debugging here: https://fb.workplace.com/groups/1075192433118967/permalink/1618308128807392/

Differential Revision: [D72600784](https://our.internmc.facebook.com/intern/diff/D72600784)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150777
Approved by: https://github.com/eellison, https://github.com/jansel
2025-04-08 18:49:44 +00:00
rzou
aae36929ed Rename node.meta["arg_kwarg_vals"] to node.meta["eager_input_vals"] (#148092)
And added a comment about it. Otherwise it might be confusing

Test Plan:
- wait for CI

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148092
Approved by: https://github.com/eellison
ghstack dependencies: #148046, #148063, #148091
2025-04-02 13:18:04 +00:00
rzou
c69c3c885e Add needs_exact_strides operator tag for Inductor to force exact strides (#148063)
Inductor will force exact strides on a custom operator tagged with
needs_exact_strides. I'll make this the default in a follow-up PR.
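
A hedged sketch of tagging a custom op with the new tag; `torch.Tag.needs_exact_strides` is what this PR adds, while the library and op names are hypothetical:

```python
import torch
from torch.library import Library, impl

lib = Library("mylib", "FRAGMENT")  # hypothetical library
lib.define(
    "stride_sensitive(Tensor x) -> Tensor",
    tags=(torch.Tag.needs_exact_strides,),
)

@impl(lib, "stride_sensitive", "CompositeExplicitAutograd")
def stride_sensitive(x):
    # With the tag, Inductor passes x with its exact eager strides.
    return x.clone()
```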

Test Plan:
- tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148063
Approved by: https://github.com/eellison
ghstack dependencies: #148046
2025-04-02 13:17:58 +00:00
Shangdi Yu
cc58ecceea Move dump location to avoid dumping twice (#150219)
Summary:
If we put the dumping code in codegen, we might get a separate node_mapping dump for the constant folded graph (https://github.com/pytorch/pytorch/blob/main/torch/_inductor/compile_fx.py#L1119).

We move it into compile_fx.py so there's only one node_mapping dump.

Test Plan: CI

Reviewed By: YUNQIUGUO

Differential Revision: D72068715

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150219
Approved by: https://github.com/YUNQIUGUO
2025-03-30 03:35:38 +00:00
Mu-Chu Lee
a0253d2840 [Inductor] Use real input to autotune user defined triton kernels (#149553)
Summary:
User-defined Triton kernels sometimes rely on real inputs to determine
the path of execution. We need real inputs to invoke the correct
behavior of the user-defined Triton kernels (see the example in the
test case, where we have an early return for random inputs).
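
A hedged sketch of the failure mode: a user-defined Triton kernel whose control flow depends on input values, so autotuning on random data would exercise the wrong path (the kernel is invented for illustration):

```python
import triton
import triton.language as tl

@triton.jit
def scale_if_flagged(x_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    # Value-dependent branch: random autotuning inputs would almost never
    # reproduce the flag pattern present in the real data.
    flag = tl.load(x_ptr)  # first element acts as a mode switch
    y = tl.where(flag > 0, x * 2.0, x)
    tl.store(out_ptr + offs, y, mask=mask)
```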

Test Plan:
Included in the commit.
python test/inductor/test_aot_inductor.py -k triton_autotuning
python test/inductor/test_aot_inductor.py -k triton_mutated_autotuning

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149553
Approved by: https://github.com/davidberard98, https://github.com/eellison
2025-03-26 16:42:48 +00:00
Shangdi Yu
46dd226702 Fakify torchbind objects in compile_fx and add tests for SigridTransformsInstanceTorchBind (#149529)
Summary:
We need to properly fakify torchbind objects, including the ones in graph module attributes, so the registered fake implementation works properly.

- _fakify_script_objects in `compile_fx`
- Allow fake torchbind objects in `torchbind_constants`

Remove `node.meta["unbacked_bindings"]` for `aot_compile` in `compile_fx`. Otherwise `ShapeProp` will fail when trying to resolve the `unbacked_bindings` of `with_effect` tokens.

Update `sigrid_transforms_test` to use the latest `torch._inductor.aot_compile` API.

Add a test for `Fakify torchbind objects in compile_fx and add tests for SigridTransformsInstanceTorchBind` in `e2e_test`.

Test Plan:
```
buck run //caffe2/torch/fb/sparsenn:sigrid_test -- -r test_transform_torch_bind

buck run //sigmoid/inference/test:e2e_test_cpu -- -r SigridTransforms

buck2 run mode/dev-nosan sigmoid/inference/ts_migration:pt2i_readiness_main -- --model_id 545017754 --test_suite ads_all --mode test_preproc

```

Differential Revision: D70013257

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149529
Approved by: https://github.com/angelayi
2025-03-21 18:58:28 +00:00
Boyuan Feng
3e605fe46d [CUDAGraph] Graph Partition (#147648)
This PR implements cudagraph partition, following the previous PR on inductor graph partition (#147038). Since there are many ops that cudagraph cannot support, this PR focuses on `cpu ops`; more partition rules will be added in the next PR.

## Example
```python
import torch

torch._inductor.config.graph_partition = True

def f(x, y):
    x1 = x + 1
    y1 = y + 1
    y_cpu = y1.cpu() + 1
    z = x @ y
    return x1 + y1 + z + y_cpu.cuda()

x, y = [torch.ones(2, 2, device="cuda") for _ in range(2)]
x_cloned, y_cloned = [tmp.clone() for tmp in [x,y]]
eager_out = f(x, y)

f_compiled = torch.compile(f, mode="reduce-overhead")

for _ in range(5):
    compiled_out = f_compiled(x_cloned, y_cloned)
    assert torch.allclose(eager_out, compiled_out)
```

w/o graph partition, we will skip cudagraph:
```
skipping cudagraphs due to skipping cudagraphs due to cpu device (device_put). Found from :
   File "/home/boyuan/playground/cudagraph/graph_partition/graph_partition.py", line 9, in f
    y_cpu = y1.cpu() + 1 # 3
```

w/ graph partition, we can see two cudagraphify under the same torch-compiled region:
![image](https://github.com/user-attachments/assets/4e22d428-2687-433d-b92a-0814a2201b25)

## Design

PR #147038 splits the `def call(args)` function into multiple `def partition_id(args)` functions. In this PR, we use `recursively_apply_fns()` to wrap each `partition_id()` function with `cudagraphify`. One major design point: `cudagraphify` takes metadata such as static_input_idxs, and we need to provide such metadata for each graph partition. However, we previously only had such metadata for the original graph, not for graph partitions.

The [idea](https://github.com/pytorch/pytorch/pull/147038#discussion_r1964124800) is:
- compute a mapping from the partition metadata (e.g., input/output idx) to the graph metadata, stored in `GraphPartitionMap` (sketched after this list).
- during post_compile, get the `CudagraphMetadata` for each partition based on the graph-level metadata and `GraphPartitionMap`, via `get_partition_cudagraph_metadata()`.
- finally, in `cudagraph_partition_pos_compile`, we compute the `CudagraphMetadata` and apply cudagraphify for each graph via `recursively_apply_fns`.
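
A minimal sketch of that mapping; the field names are illustrative, not the PR's actual `GraphPartitionMap` definition:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class GraphPartitionMap:
    partition_index: int
    # For each partition input/output: its position among the original
    # graph's inputs/outputs, or None if the value is internal (produced
    # or consumed by another partition).
    input_indices: list[Optional[int]]
    output_indices: list[Optional[int]]
```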

#### Q: How does it work with codecache?

While we have multiple graph partitions, we still have 1 file and 1 `call` function for 1 dynamo graph. The major difference is we need to additionally load a `recursively_apply_fns()` for graph partition. We also add `partition_maps: Optional[list[GraphPartitionMap]]` to `CompiledFxGraph` so it will be serialized and could be deserialized later.

## Edge Case 1
PyTorch has an assumption on input/output orders. For example, backward inputs take saved tensors first and then tangents. In graph partition, we respect such orders via `graph_partition_signature_reorder`.

## Edge Case 2
Cudagraphifying `call` function gives 2 cudagraph managed tensors `buf0` and `primals_1`. However, cudagraphifying `partition_0` gives only 1 cudagraph managed tensor `buf0`. This leads to a semantic difference between cudagraph w/ and w/o graph partition. [full code comparison](https://www.internalfb.com/intern/diffing/?paste_number=1747654420)

![image](https://github.com/user-attachments/assets/03d08ce0-f1d1-4d1d-8432-805a07e1dd40)

To achieve the same semantics, we return an input tensor as an output if it is not freed in a graph partition. This allows more cudagraph-managed tensors and is important for handling saved tensors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147648
Approved by: https://github.com/eellison
2025-03-13 16:00:21 +00:00
FFFrog
416ea1c71c Code Clean: Remove unnecessary code (#148735)
As the title stated.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148735
Approved by: https://github.com/jingsh, https://github.com/cyyever
2025-03-07 08:15:37 +00:00
Shunting Zhang
6cc3e69103 [inductor] use eager stride for custom op if no tags (#148367)
Fix https://github.com/pytorch/pytorch/issues/148356

This is some sort of short term fix to recover the default behavior to apply layout constraint for custom ops when there are no tags.

A longer term attempt to make sure Inductor always gets correct eager strides is here: https://github.com/pytorch/pytorch/pull/148104

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148367
Approved by: https://github.com/eellison, https://github.com/zou3519
2025-03-06 00:58:00 +00:00
Benjamin Glass
d6d670ab4d [AOTI] build CPU CPP kernels at O3, and all other code at O1 (#148587)
In the future, we may also want to add LTO linking to further optimize the results (while still hopefully netting compile time benefits).

Differential Revision: [D70641543](https://our.internmc.facebook.com/intern/diff/D70641543)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148587
Approved by: https://github.com/desertfire
2025-03-05 22:47:46 +00:00
Animesh Jain
fd16311e7f [inductor][subgraph] Plumbing to get ShapeAsConstantBuffer from subgraph to main graph output (#147559)
I am unable to create a test case that fails without the next PR. The idea is to have a symint which is returned by the inner subgraph and then returned by the forward graph after partitioning.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147559
Approved by: https://github.com/eellison
2025-03-01 06:17:11 +00:00
Xuehai Pan
1cb4e2df65 [BE][PYFMT] migrate PYFMT for torch._inductor to ruff format (#144550)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144550
Approved by: https://github.com/jansel
2025-02-28 13:33:19 +00:00
eellison
481a57bc37 Support torch.compile rng selective activation checkpointing with cudagraph (#146878)
TODO:
- [x]  Add handling for when forward is invoked multiple times without invoking backward, in which case the fwd/backward states get out of sync
- [x] Update rng state initialization to take from correct device
- [x]  Tests
- [x] handling of retain_graph
- [x] respect fallback random

Fix for https://github.com/pytorch/pytorch/issues/130123.

Updates the aot_eager and cudagraph compilation of `run_and_save_rng_state` to use the new mechanism added by https://github.com/pytorch/pytorch/pull/114068 for CUDAGraph safe rng states.

We have a pair of rng states for the fwd and backward respectively. In both forward and backward the rng op will get run with `graphsafe_run_with_rng_state` which takes in RNG state and it hooks onto the current RNG generator before running the operator. The rng states for fwd/backward are initialized with the same value. We ensure that for any given run of the forward, the corresponding backward run will have the same rng states for the op as was observed in the forward.

```
 ===== Forward graph 1 =====
 /data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", fwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = fwd_rng_state_0);  fwd_rng_state_0 = None
        ...

 ===== Backward graph 1 =====
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", tangents_1: "f32[4, 4][4, 1]cuda:0", bwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = bwd_rng_state_0);  bwd_rng_state_0 = None
```
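
A hedged user-level sketch of the scenario these graphs come from: an rng op under activation checkpointing, compiled with cudagraphs (`use_reentrant=False` checkpointing is the supported combination with `torch.compile`):

```python
import torch
from torch.utils.checkpoint import checkpoint

def block(x):
    # rng op whose value must match between the forward pass and the
    # recompute that happens during backward
    return torch.nn.functional.dropout(torch.sin(x), p=0.5, training=True)

def f(x):
    return checkpoint(block, x, use_reentrant=False).sum()

x = torch.randn(4, 4, device="cuda", requires_grad=True)
compiled = torch.compile(f, mode="reduce-overhead")
for _ in range(3):
    compiled(x).backward()
```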

There is some extra complication when a user either calls backward with retain_graph, or calls the backwards in a different order than they called the forwards. If a user has states fwd_rng_state0, bwd_rng_state0 and calls:
- fwd0: fwd_rng_state0 -> fwd_rng_state1
- fwd1: fwd_rng_state1 -> fwd_rng_state2
- bwd1
- bwd0

Then naively, when bwd1 is invoked, the bwd rng states would not equal the states that were observed in fwd1. I added handling of this in the aot runtime wrappers to detect pending backward invocations and the current position of the bwd rng states, and to update them when necessary.

Other notes:

Because nodes which appear later in the forward appear earlier in the backward, we need a separate rng state for each operator. If we reused the rng across ops, the forward and backward would run with different rng states, i.e., not applied in the same order.

Questions for reviewers:

This does change numerics, because the rng of the op is now taken from the input rng state instead of whatever the rng would be midway through running the graph. Technically, we only need this for cudagraphs, but I'd prefer not to have an rng divergence just for cudagraphs. I am making it respect `fallback_random`.

Edit: decided to apply to non cudagraphs as well, so long as fallback_random is not set

I'm initializing the rng states by cloning the current state. If you had something like 5 different rands in the model with the same shape, they'd all get the same value. This doesn't seem great. I could use some other initialization scheme, like taking the seed from the graph position, etc. Not sure; let me know your thoughts.

Edit: updated to be taken from randint()

Update: initializing rng states from torch.randint().

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146878
Approved by: https://github.com/anijain2305, https://github.com/bdhirsh
2025-02-28 00:47:03 +00:00
Boyuan Feng
b6fe28ff02 [Inductor] Graph Partition (#147038)
This PR implements inductor graph partition. Previously, 1 dynamo graph was mapped to 1 inductor graph, and further mapped to 1 call function. In this PR, we allow 1 dynamo graph to map to multiple inductor graphs and multiple `graph_partition` functions in the generated code. This allows applying different further optimizations to each `graph_partition`.

Design Doc: [link](https://docs.google.com/document/d/1qPgOfy25l7SIYnrQrvU-TO1mdHMslCwv_SLmeXID6tM/edit?usp=sharing)
Example: [Generated code before and after this diff](https://www.internalfb.com/intern/diffing/?paste_number=1737334601)

In the follow-up PR, we will extend the work to cudagraph, which allows applying cudagraph to parts of the generated code (#125864).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147038
Approved by: https://github.com/eellison
2025-02-27 04:50:43 +00:00
PyTorch MergeBot
17358ce778 Revert "Support torch.compile rng selective activation checkpointing with cudagraph (#146878)"
This reverts commit ad0c879e22.

Reverted https://github.com/pytorch/pytorch/pull/146878 on behalf of https://github.com/wdvr due to lint failure ([comment](https://github.com/pytorch/pytorch/pull/146878#issuecomment-2686767956))
2025-02-27 03:36:16 +00:00