Don't skip register-spilling configs in custom Triton kernel auto-tuning (#119634)

Summary: There has been some empirical evidence that, for (non-trivial) custom (user-written) Triton kernels, a register-spilling config yields the best result in auto-tuning. For this reason, we don't skip register-spilling config from auto-tuning of the custom Triton kernels.

<details>
<summary>An example of auto-tuning result with the register-spilling config outperforming others</summary>

```
BLOCK_M: 16, BLOCK_N: 16, num_warps: 2, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.748896, nreg 255, nspill 0, #shared-mem 8704
BLOCK_M: 16, BLOCK_N: 16, num_warps: 4, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 1.723424, nreg 249, nspill 0, #shared-mem 8704
BLOCK_M: 16, BLOCK_N: 16, num_warps: 8, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 2.202656, nreg 190, nspill 0, #shared-mem 8704
BLOCK_M: 16, BLOCK_N: 16, num_warps: 2, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.748256, nreg 255, nspill 0, #shared-mem 8704
BLOCK_M: 16, BLOCK_N: 16, num_warps: 4, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 1.724896, nreg 249, nspill 0, #shared-mem 8704
BLOCK_M: 16, BLOCK_N: 16, num_warps: 8, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 2.201632, nreg 190, nspill 0, #shared-mem 8704
BLOCK_M: 16, BLOCK_N: 32, num_warps: 2, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.651664, nreg 255, nspill 56, #shared-mem 13312
BLOCK_M: 16, BLOCK_N: 32, num_warps: 4, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.846368, nreg 255, nspill 14, #shared-mem 13312
BLOCK_M: 16, BLOCK_N: 32, num_warps: 8, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 1.841792, nreg 243, nspill 0, #shared-mem 13312
BLOCK_M: 16, BLOCK_N: 32, num_warps: 2, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.651584, nreg 255, nspill 56, #shared-mem 13312
BLOCK_M: 16, BLOCK_N: 32, num_warps: 4, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.846432, nreg 255, nspill 14, #shared-mem 13312
BLOCK_M: 16, BLOCK_N: 32, num_warps: 8, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 1.841904, nreg 243, nspill 0, #shared-mem 13312
BLOCK_M: 16, BLOCK_N: 64, num_warps: 2, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 1.236448, nreg 255, nspill 254, #shared-mem 22528
BLOCK_M: 16, BLOCK_N: 64, num_warps: 4, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 1.484384, nreg 255, nspill 174, #shared-mem 22528
BLOCK_M: 16, BLOCK_N: 64, num_warps: 8, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 1.131168, nreg 255, nspill 6, #shared-mem 22528
BLOCK_M: 16, BLOCK_N: 64, num_warps: 2, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 1.236544, nreg 255, nspill 254, #shared-mem 22528
BLOCK_M: 16, BLOCK_N: 64, num_warps: 4, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 1.483648, nreg 255, nspill 174, #shared-mem 22528
BLOCK_M: 16, BLOCK_N: 64, num_warps: 8, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 1.131408, nreg 255, nspill 6, #shared-mem 22528
BLOCK_M: 32, BLOCK_N: 16, num_warps: 2, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.516112, nreg 255, nspill 28, #shared-mem 13312
BLOCK_M: 32, BLOCK_N: 16, num_warps: 4, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.737792, nreg 255, nspill 0, #shared-mem 13312
BLOCK_M: 32, BLOCK_N: 16, num_warps: 8, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 1.411632, nreg 193, nspill 0, #shared-mem 13312
BLOCK_M: 32, BLOCK_N: 16, num_warps: 2, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.515904, nreg 255, nspill 28, #shared-mem 13312
BLOCK_M: 32, BLOCK_N: 16, num_warps: 4, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.736608, nreg 255, nspill 0, #shared-mem 13312
BLOCK_M: 32, BLOCK_N: 16, num_warps: 8, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 1.409808, nreg 193, nspill 0, #shared-mem 13312
BLOCK_M: 32, BLOCK_N: 32, num_warps: 2, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.553536, nreg 255, nspill 130, #shared-mem 18432
BLOCK_M: 32, BLOCK_N: 32, num_warps: 4, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.569792, nreg 255, nspill 56, #shared-mem 18432
BLOCK_M: 32, BLOCK_N: 32, num_warps: 8, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.892448, nreg 255, nspill 4, #shared-mem 18432
BLOCK_M: 32, BLOCK_N: 32, num_warps: 2, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.553584, nreg 255, nspill 130, #shared-mem 18432
BLOCK_M: 32, BLOCK_N: 32, num_warps: 4, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.569568, nreg 255, nspill 56, #shared-mem 18432
BLOCK_M: 32, BLOCK_N: 32, num_warps: 8, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.892240, nreg 255, nspill 4, #shared-mem 18432
BLOCK_M: 32, BLOCK_N: 64, num_warps: 2, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 1.332928, nreg 255, nspill 366, #shared-mem 28672
BLOCK_M: 32, BLOCK_N: 64, num_warps: 4, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.922256, nreg 255, nspill 228, #shared-mem 28672
BLOCK_M: 32, BLOCK_N: 64, num_warps: 8, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.758400, nreg 255, nspill 26, #shared-mem 28672
BLOCK_M: 32, BLOCK_N: 64, num_warps: 2, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 1.333440, nreg 255, nspill 366, #shared-mem 28672
BLOCK_M: 32, BLOCK_N: 64, num_warps: 4, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.922336, nreg 255, nspill 228, #shared-mem 28672
BLOCK_M: 32, BLOCK_N: 64, num_warps: 8, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.758496, nreg 255, nspill 26, #shared-mem 28672
BLOCK_M: 64, BLOCK_N: 16, num_warps: 2, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 1.231648, nreg 255, nspill 292, #shared-mem 22528
BLOCK_M: 64, BLOCK_N: 16, num_warps: 4, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.639424, nreg 255, nspill 90, #shared-mem 22528
BLOCK_M: 64, BLOCK_N: 16, num_warps: 8, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.917952, nreg 240, nspill 0, #shared-mem 22528
BLOCK_M: 64, BLOCK_N: 16, num_warps: 2, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 1.230624, nreg 255, nspill 292, #shared-mem 22528
BLOCK_M: 64, BLOCK_N: 16, num_warps: 4, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.639168, nreg 255, nspill 90, #shared-mem 22528
BLOCK_M: 64, BLOCK_N: 16, num_warps: 8, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.917440, nreg 240, nspill 0, #shared-mem 22528
BLOCK_M: 64, BLOCK_N: 32, num_warps: 2, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.838080, nreg 255, nspill 354, #shared-mem 28672
BLOCK_M: 64, BLOCK_N: 32, num_warps: 4, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.569184, nreg 255, nspill 178, #shared-mem 28672
BLOCK_M: 64, BLOCK_N: 32, num_warps: 8, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.614720, nreg 255, nspill 28, #shared-mem 28672
BLOCK_M: 64, BLOCK_N: 32, num_warps: 2, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.838048, nreg 255, nspill 354, #shared-mem 28672
BLOCK_M: 64, BLOCK_N: 32, num_warps: 4, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.569472, nreg 255, nspill 178, #shared-mem 28672
BLOCK_M: 64, BLOCK_N: 32, num_warps: 8, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.615104, nreg 255, nspill 28, #shared-mem 28672
BLOCK_M: 64, BLOCK_N: 64, num_warps: 2, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 1.012128, nreg 255, nspill 522, #shared-mem 40960
BLOCK_M: 64, BLOCK_N: 64, num_warps: 4, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.861536, nreg 255, nspill 378, #shared-mem 40960
BLOCK_M: 64, BLOCK_N: 64, num_warps: 8, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.771584, nreg 255, nspill 134, #shared-mem 40960
BLOCK_M: 64, BLOCK_N: 64, num_warps: 2, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 1.012512, nreg 255, nspill 522, #shared-mem 40960
BLOCK_M: 64, BLOCK_N: 64, num_warps: 4, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.861024, nreg 255, nspill 378, #shared-mem 40960
BLOCK_M: 64, BLOCK_N: 64, num_warps: 8, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.771712, nreg 255, nspill 134, #shared-mem 40960
```

</details>

In the above, the winning config is `BLOCK_M: 32, BLOCK_N: 16, num_warps: 2, num_ctas: 1, num_stages: 2`, although it has non-zero `nspill 28`. This is an example where we need to consider all configs, including the register-spilling ones, to obtain the best result from auto-tuning.

In the worst case, this will just make auto-tuning longer, but can't regress the results. And, as the number of custom Triton kernels in the model is normally much smaller than the number of Inductor-generated ones, this should be acceptable.

Test Plan: CI

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119634
Approved by: https://github.com/oulgen
This commit is contained in:
Adnan Akhundov 2024-02-09 23:24:04 -08:00 committed by PyTorch MergeBot
parent 3ab08946d5
commit 0bed0501fa
2 changed files with 17 additions and 3 deletions

View File

@ -1066,7 +1066,8 @@ class WrapperCodeGen(CodeGen):
configs={configs!r},
inductor_meta={inductor_meta!r},
triton_meta={triton_meta!r},
filename=__file__
filename=__file__,
custom_kernel=True,
)
@triton.jit
"""

View File

@ -144,6 +144,7 @@ class CachingAutotuner(KernelInterface):
heuristic_type,
size_hints=None,
inductor_meta=None, # metadata not relevant to triton
custom_kernel=False, # whether the kernel is inductor-generated or custom
):
super().__init__()
@ -155,6 +156,7 @@ class CachingAutotuner(KernelInterface):
self.mutated_arg_names = mutated_arg_names
self.configs = configs
self.heuristic_type = heuristic_type
self.custom_kernel = custom_kernel
# Align the default design that default as cuda
self.device_type = (
@ -428,7 +430,12 @@ class CachingAutotuner(KernelInterface):
def bench(self, launcher, *args, grid, **kwargs):
"""Measure the performance of a given launcher"""
if launcher.n_spills > config.triton.spill_threshold:
# we don't skip configs wiht spilled registers when auto-tuning custom
# (user-written) Triton kernels, as (i) we don't have any knowledge or
# control over the kernel code; (ii) there is empirical evidence that
# for some (complicated) custom Triton kernels, a register-spilling
# config may yield the best latency.
if not self.custom_kernel and launcher.n_spills > config.triton.spill_threshold:
log.debug(
"Skip config %s because of register spilling: %d",
launcher.config,
@ -803,6 +810,7 @@ def cached_autotune(
heuristic_type,
filename=None,
inductor_meta=None,
custom_kernel=False,
):
"""
A copy of triton.autotune that calls our subclass. Our subclass
@ -867,6 +875,7 @@ def cached_autotune(
mutated_arg_names=mutated_arg_names,
heuristic_type=heuristic_type,
size_hints=size_hints,
custom_kernel=custom_kernel,
)
return CachingAutotuner(
fn,
@ -877,6 +886,7 @@ def cached_autotune(
mutated_arg_names=mutated_arg_names,
heuristic_type=heuristic_type,
size_hints=size_hints,
custom_kernel=custom_kernel,
)
return decorator
@ -1349,7 +1359,9 @@ def template(num_stages, num_warps, triton_meta, filename=None, inductor_meta=No
)
def user_autotune(configs, triton_meta, filename=None, inductor_meta=None):
def user_autotune(
configs, triton_meta, filename=None, inductor_meta=None, custom_kernel=False
):
"""
Compile a user defined triton kernel
"""
@ -1380,6 +1392,7 @@ def user_autotune(configs, triton_meta, filename=None, inductor_meta=None):
heuristic_type=HeuristicType.USER_AUTOTUNE,
filename=filename,
inductor_meta=inductor_meta,
custom_kernel=custom_kernel,
)