pytorch/torch/_inductor/codegen
Adnan Akhundov 0bed0501fa Don't skip register-spilling configs in custom Triton kernel auto-tuning (#119634)
Summary: There is some empirical evidence that, for non-trivial custom (user-written) Triton kernels, a register-spilling config can yield the best result in auto-tuning. For this reason, we no longer skip register-spilling configs when auto-tuning custom Triton kernels.
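The change can be pictured as a ranking rule. The following is a minimal sketch. It assumes a tuner that rules out a config by reporting infinite latency and then takes the minimum. `effective_latency`, `custom_kernel`, and `spill_threshold` (default 0 here) are hypothetical names chosen for illustration. They are not Inductor's actual code or configuration.

```python
import math


def effective_latency(measured_latency: float, n_spills: int,
                      custom_kernel: bool, spill_threshold: int = 0) -> float:
    """Latency a hypothetical tuner compares when ranking one config.

    Hypothetical helper, not Inductor's API. For generated kernels, a config
    that spills more than `spill_threshold` registers is ruled out by reporting
    infinite latency, so a later `min()` never picks it. For custom
    (user-written) kernels the measured latency is always kept, so a spilling
    config can still win.
    """
    if not custom_kernel and n_spills > spill_threshold:
        return math.inf
    return measured_latency


# Numbers of the winning config in the log below: latency 0.515904, nspill 28.
print(effective_latency(0.515904, 28, custom_kernel=False))  # inf
print(effective_latency(0.515904, 28, custom_kernel=True))   # 0.515904
```

Reporting `inf` instead of removing the entry keeps the candidate list intact, so the selection step stays a plain minimum in both modes.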

<details>
<summary>An example of auto-tuning result with the register-spilling config outperforming others</summary>

```
BLOCK_M: 16, BLOCK_N: 16, num_warps: 2, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.748896, nreg 255, nspill 0, #shared-mem 8704
BLOCK_M: 16, BLOCK_N: 16, num_warps: 4, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 1.723424, nreg 249, nspill 0, #shared-mem 8704
BLOCK_M: 16, BLOCK_N: 16, num_warps: 8, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 2.202656, nreg 190, nspill 0, #shared-mem 8704
BLOCK_M: 16, BLOCK_N: 16, num_warps: 2, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.748256, nreg 255, nspill 0, #shared-mem 8704
BLOCK_M: 16, BLOCK_N: 16, num_warps: 4, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 1.724896, nreg 249, nspill 0, #shared-mem 8704
BLOCK_M: 16, BLOCK_N: 16, num_warps: 8, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 2.201632, nreg 190, nspill 0, #shared-mem 8704
BLOCK_M: 16, BLOCK_N: 32, num_warps: 2, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.651664, nreg 255, nspill 56, #shared-mem 13312
BLOCK_M: 16, BLOCK_N: 32, num_warps: 4, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.846368, nreg 255, nspill 14, #shared-mem 13312
BLOCK_M: 16, BLOCK_N: 32, num_warps: 8, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 1.841792, nreg 243, nspill 0, #shared-mem 13312
BLOCK_M: 16, BLOCK_N: 32, num_warps: 2, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.651584, nreg 255, nspill 56, #shared-mem 13312
BLOCK_M: 16, BLOCK_N: 32, num_warps: 4, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.846432, nreg 255, nspill 14, #shared-mem 13312
BLOCK_M: 16, BLOCK_N: 32, num_warps: 8, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 1.841904, nreg 243, nspill 0, #shared-mem 13312
BLOCK_M: 16, BLOCK_N: 64, num_warps: 2, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 1.236448, nreg 255, nspill 254, #shared-mem 22528
BLOCK_M: 16, BLOCK_N: 64, num_warps: 4, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 1.484384, nreg 255, nspill 174, #shared-mem 22528
BLOCK_M: 16, BLOCK_N: 64, num_warps: 8, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 1.131168, nreg 255, nspill 6, #shared-mem 22528
BLOCK_M: 16, BLOCK_N: 64, num_warps: 2, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 1.236544, nreg 255, nspill 254, #shared-mem 22528
BLOCK_M: 16, BLOCK_N: 64, num_warps: 4, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 1.483648, nreg 255, nspill 174, #shared-mem 22528
BLOCK_M: 16, BLOCK_N: 64, num_warps: 8, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 1.131408, nreg 255, nspill 6, #shared-mem 22528
BLOCK_M: 32, BLOCK_N: 16, num_warps: 2, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.516112, nreg 255, nspill 28, #shared-mem 13312
BLOCK_M: 32, BLOCK_N: 16, num_warps: 4, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.737792, nreg 255, nspill 0, #shared-mem 13312
BLOCK_M: 32, BLOCK_N: 16, num_warps: 8, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 1.411632, nreg 193, nspill 0, #shared-mem 13312
BLOCK_M: 32, BLOCK_N: 16, num_warps: 2, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.515904, nreg 255, nspill 28, #shared-mem 13312
BLOCK_M: 32, BLOCK_N: 16, num_warps: 4, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.736608, nreg 255, nspill 0, #shared-mem 13312
BLOCK_M: 32, BLOCK_N: 16, num_warps: 8, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 1.409808, nreg 193, nspill 0, #shared-mem 13312
BLOCK_M: 32, BLOCK_N: 32, num_warps: 2, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.553536, nreg 255, nspill 130, #shared-mem 18432
BLOCK_M: 32, BLOCK_N: 32, num_warps: 4, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.569792, nreg 255, nspill 56, #shared-mem 18432
BLOCK_M: 32, BLOCK_N: 32, num_warps: 8, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.892448, nreg 255, nspill 4, #shared-mem 18432
BLOCK_M: 32, BLOCK_N: 32, num_warps: 2, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.553584, nreg 255, nspill 130, #shared-mem 18432
BLOCK_M: 32, BLOCK_N: 32, num_warps: 4, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.569568, nreg 255, nspill 56, #shared-mem 18432
BLOCK_M: 32, BLOCK_N: 32, num_warps: 8, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.892240, nreg 255, nspill 4, #shared-mem 18432
BLOCK_M: 32, BLOCK_N: 64, num_warps: 2, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 1.332928, nreg 255, nspill 366, #shared-mem 28672
BLOCK_M: 32, BLOCK_N: 64, num_warps: 4, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.922256, nreg 255, nspill 228, #shared-mem 28672
BLOCK_M: 32, BLOCK_N: 64, num_warps: 8, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.758400, nreg 255, nspill 26, #shared-mem 28672
BLOCK_M: 32, BLOCK_N: 64, num_warps: 2, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 1.333440, nreg 255, nspill 366, #shared-mem 28672
BLOCK_M: 32, BLOCK_N: 64, num_warps: 4, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.922336, nreg 255, nspill 228, #shared-mem 28672
BLOCK_M: 32, BLOCK_N: 64, num_warps: 8, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.758496, nreg 255, nspill 26, #shared-mem 28672
BLOCK_M: 64, BLOCK_N: 16, num_warps: 2, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 1.231648, nreg 255, nspill 292, #shared-mem 22528
BLOCK_M: 64, BLOCK_N: 16, num_warps: 4, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.639424, nreg 255, nspill 90, #shared-mem 22528
BLOCK_M: 64, BLOCK_N: 16, num_warps: 8, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.917952, nreg 240, nspill 0, #shared-mem 22528
BLOCK_M: 64, BLOCK_N: 16, num_warps: 2, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 1.230624, nreg 255, nspill 292, #shared-mem 22528
BLOCK_M: 64, BLOCK_N: 16, num_warps: 4, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.639168, nreg 255, nspill 90, #shared-mem 22528
BLOCK_M: 64, BLOCK_N: 16, num_warps: 8, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.917440, nreg 240, nspill 0, #shared-mem 22528
BLOCK_M: 64, BLOCK_N: 32, num_warps: 2, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.838080, nreg 255, nspill 354, #shared-mem 28672
BLOCK_M: 64, BLOCK_N: 32, num_warps: 4, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.569184, nreg 255, nspill 178, #shared-mem 28672
BLOCK_M: 64, BLOCK_N: 32, num_warps: 8, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.614720, nreg 255, nspill 28, #shared-mem 28672
BLOCK_M: 64, BLOCK_N: 32, num_warps: 2, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.838048, nreg 255, nspill 354, #shared-mem 28672
BLOCK_M: 64, BLOCK_N: 32, num_warps: 4, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.569472, nreg 255, nspill 178, #shared-mem 28672
BLOCK_M: 64, BLOCK_N: 32, num_warps: 8, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.615104, nreg 255, nspill 28, #shared-mem 28672
BLOCK_M: 64, BLOCK_N: 64, num_warps: 2, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 1.012128, nreg 255, nspill 522, #shared-mem 40960
BLOCK_M: 64, BLOCK_N: 64, num_warps: 4, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.861536, nreg 255, nspill 378, #shared-mem 40960
BLOCK_M: 64, BLOCK_N: 64, num_warps: 8, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.771584, nreg 255, nspill 134, #shared-mem 40960
BLOCK_M: 64, BLOCK_N: 64, num_warps: 2, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 1.012512, nreg 255, nspill 522, #shared-mem 40960
BLOCK_M: 64, BLOCK_N: 64, num_warps: 4, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.861024, nreg 255, nspill 378, #shared-mem 40960
BLOCK_M: 64, BLOCK_N: 64, num_warps: 8, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.771712, nreg 255, nspill 134, #shared-mem 40960
```

</details>

In the above, the winning config is `BLOCK_M: 32, BLOCK_N: 16, num_warps: 2, num_ctas: 1, num_stages: 2` (0.515904), even though it spills registers (`nspill 28`). This is an example where we need to consider all configs, including the register-spilling ones, to get the best result from auto-tuning.
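To check this against the data, the sketch below parses five rows copied verbatim from the log above. It then picks the fastest config twice: once with spilling configs kept and once with them dropped. `parse`, `best`, and `Trial` are illustrative names. The filter treats any nonzero `nspill` as disqualifying. That is an assumption, since a real tuner may apply a nonzero spill threshold.

```python
import re
from typing import NamedTuple

# Five rows copied verbatim from the auto-tuning log above.
LOG = """\
BLOCK_M: 16, BLOCK_N: 16, num_warps: 2, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.748256, nreg 255, nspill 0, #shared-mem 8704
BLOCK_M: 32, BLOCK_N: 16, num_warps: 2, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.516112, nreg 255, nspill 28, #shared-mem 13312
BLOCK_M: 32, BLOCK_N: 16, num_warps: 2, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.515904, nreg 255, nspill 28, #shared-mem 13312
BLOCK_M: 32, BLOCK_N: 16, num_warps: 4, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.736608, nreg 255, nspill 0, #shared-mem 13312
BLOCK_M: 64, BLOCK_N: 32, num_warps: 4, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.569184, nreg 255, nspill 178, #shared-mem 28672
"""

ROW = re.compile(
    r"BLOCK_M: (?P<bm>\d+), BLOCK_N: (?P<bn>\d+), num_warps: (?P<warps>\d+), "
    r"num_ctas: \d+, num_stages: (?P<stages>\d+), "
    r"enable_warp_specialization: \w+, enable_persistent: \w+: "
    r"(?P<latency>[\d.]+), nreg \d+, nspill (?P<nspill>\d+)"
)


class Trial(NamedTuple):
    config: tuple  # (BLOCK_M, BLOCK_N, num_warps, num_stages)
    latency: float
    nspill: int


def parse(log: str) -> list[Trial]:
    """Extract tuning parameters, latency, and spill count from each log row."""
    trials = []
    for line in log.splitlines():
        m = ROW.search(line)
        if m:
            config = (int(m["bm"]), int(m["bn"]), int(m["warps"]), int(m["stages"]))
            trials.append(Trial(config, float(m["latency"]), int(m["nspill"])))
    return trials


def best(trials: list[Trial], skip_spilling: bool) -> Trial:
    """Return the fastest trial, optionally dropping every config that spills."""
    candidates = [t for t in trials if not (skip_spilling and t.nspill > 0)]
    return min(candidates, key=lambda t: t.latency)


trials = parse(LOG)
with_spills = best(trials, skip_spilling=False)
without_spills = best(trials, skip_spilling=True)
print(with_spills.config, with_spills.latency)        # (32, 16, 2, 2) 0.515904
print(without_spills.config, without_spills.latency)  # (32, 16, 4, 2) 0.736608
```

Dropping the spilling configs would select `(32, 16, 4, 2)` at 0.736608 rather than 0.515904. On this kernel, keeping spilling configs therefore lowers the measured latency by roughly 30%. Across the full log, these same two configs are the overall best and the best non-spilling one.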

In the worst case, this only makes auto-tuning take longer; it cannot regress the results. And since the number of custom Triton kernels in a model is normally much smaller than the number of Inductor-generated ones, the extra tuning time should be acceptable.
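Stated formally, let $C$ be the configs benchmarked for a custom kernel and $S \subseteq C$ the subset left when register-spilling configs are skipped. Let $t(c)$ be the measured latency of config $c$. The minimum over the larger set can never exceed the minimum over the smaller one. The price is benchmarking $|C|$ configs instead of $|S|$:

```math
S \subseteq C \;\Longrightarrow\; \min_{c \in C} t(c) \;\le\; \min_{c \in S} t(c)
```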

Test Plan: CI

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119634
Approved by: https://github.com/oulgen
2024-02-11 02:13:25 +00:00
aoti_runtime [AOTI][refactor] Split common aoti_runtime utils into a separate header (#119066) 2024-02-07 16:54:00 +00:00
cuda [Inductor max autotune] Multithreaded Precompilation (#119386) 2024-02-09 16:11:30 +00:00
__init__.py
common.py [Inductor] Add Int8 data type into Inductor CPP backend vectorized code generation (#119179) 2024-02-09 07:33:12 +00:00
cpp_prefix.h [Inductor] Add Int8 data type into Inductor CPP backend vectorized code generation (#119179) 2024-02-09 07:33:12 +00:00
cpp.py Revert "[aot_inductor] move CppWrapperCodeGen into a separate file (#119491)" 2024-02-10 23:02:05 +00:00
cuda_combined_scheduling.py [mypy] added type annotations to codegen_nodes methods (#119080) 2024-02-05 18:33:52 +00:00
memory_planning.py Remove follow_imports = skip from sympy (#118469) 2024-01-28 13:38:38 +00:00
multi_kernel.py [inductor] make multi-kernel work with cpp-wrapper (#117813) 2024-02-05 23:35:41 +00:00
triton_foreach.py [inductor] make inductor work with new triton compile interface (#115878) 2023-12-22 00:09:29 +00:00
triton_split_scan.py [inductor] Add split scan kernel (#117992) 2024-02-09 01:56:00 +00:00
triton_utils.py [inductor] Add split scan kernel (#117992) 2024-02-09 01:56:00 +00:00
triton.py [inductor] Fix compile error on scan with no mask (#119555) 2024-02-10 12:38:40 +00:00
wrapper.py Don't skip register-spilling configs in custom Triton kernel auto-tuning (#119634) 2024-02-11 02:13:25 +00:00