[pytorch][triton] Warp specialization support in TritonTemplate for torchinductor (#148503) (#150122)

Summary:
Currently only `num_warps` and `num_stages` are supported as one of the kernel options for inductor auto-tuning using `TritonTemplate`.

In order to allow warp-specialization kernel options should allow specifying `num_consumer_groups` and `num_buffers_warp_spec` as well.

NOTE: Currently gating changes to FBCODE using HAS_WARP_SPEC which is only available on triton/release-3.3.x

Test Plan:
## Unit test
Added tests for `test_triton_template_warp_specialization` to verify generated kenrnel contains configs for  `num_consumer_groups` and `num_buffers_warp_spec`.

## Functional Testing
Specific to flexattention.
```
import torch
from torch.nn.attention.flex_attention import flex_attention

from triton.testing import do_bench

make_tensor = lambda: torch.rand(8, 16, 8192, 128, device="cuda", dtype=torch.bfloat16)
q, k, v = make_tensor(), make_tensor(), make_tensor()

flex_compiled = torch.compile(flex_attention, fullgraph=True)

print(do_bench(lambda: flex_compiled(q, k, v, kernel_options={"num_warps": 4})))
```

triton do_bench results:
- default compile: 15.176783561706543
- with warp-spec: 9.452800750732422

## Extra notes
- generated triton kernel using `TORCH_LOGS=output_code`: P1740612877
- TTGIR for fused kernel: P1740614685

Differential Revision: D71982587

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150122
Approved by: https://github.com/eellison, https://github.com/zou3519, https://github.com/jansel
This commit is contained in:
Mandar Deshpande 2025-03-29 03:36:46 +00:00 committed by PyTorch MergeBot
parent 03313c6619
commit 0861af2596
9 changed files with 225 additions and 21 deletions

View File

@ -16,6 +16,7 @@ from unittest.mock import patch
import torch
from torch._dynamo.testing import CompileCounterWithBackend, normalize_gm
from torch._inductor import metrics
from torch._inductor.runtime.triton_compat import HAS_WARP_SPEC
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import run_and_get_code
from torch.nn.attention.experimental._paged_attention import PagedAttention
@ -3708,6 +3709,48 @@ class GraphModule(torch.nn.Module):
q, k, v = [torch.randn(2, 2, 128, 16, device="cuda") for _ in range(3)]
compiled_fa = torch.compile(flex_attention)
@unittest.skipIf(
not has_triton() or not HAS_WARP_SPEC,
reason="FBCODE Triton is required for this test",
)
def test_triton_template_warp_specialization(self):
def make_tensor():
return torch.rand(4, 16, 4096, 64, device="cuda", dtype=torch.bfloat16)
q, k, v = make_tensor(), make_tensor(), make_tensor()
flex_compiled = torch.compile(flex_attention, fullgraph=True)
positional_args = (q, k, v)
keyword_args = {
"kernel_options": {
"num_warps": 4,
"num_consumer_groups": 0,
"num_buffers_warp_spec": 0,
}
}
# Check if kernel code contains warp specialization parameters
_, kernel_code = run_and_get_code(
flex_compiled,
*positional_args,
**keyword_args,
)
assert kernel_code is not None, "Failed to retrieve compiled kernel code"
assert (
"num_consumer_groups" in kernel_code[0]
), "num_consumer_groups missing in kernel definition"
assert (
"num_buffers_warp_spec" in kernel_code[0]
), "num_buffers_warp_spec missing in kernel definition"
# Validate correctness
C1 = flex_compiled(q, k, v)
C2 = flex_attention(q, k, v)
assert torch.allclose(
C1, C2, atol=1e-2, rtol=1e-2
), "Warp specialized kernel result differs from reference"
class TestBlockMask(InductorTestCase):
@supported_platform

View File

@ -357,6 +357,8 @@ class TestSelectAlgorithm(TestCase):
extra_args=None,
num_stages=None,
num_warps=None,
num_consumer_groups=None,
num_buffers_warp_spec=None,
input_tensor_meta=None,
output_tensor_meta=None,
)

View File

@ -2,9 +2,11 @@
import sys
import unittest
from unittest.mock import MagicMock, patch
import torch
from torch._dynamo.testing import rand_strided
from torch._inductor.runtime.triton_compat import HAS_WARP_SPEC
from torch._inductor.utils import clone_preserve_strides
from torch.testing._internal.common_utils import IS_LINUX, skipIfXpu
from torch.testing._internal.inductor_utils import (
@ -34,6 +36,7 @@ from torch._inductor.runtime.triton_helpers import math as tl_math
from torch._inductor.runtime.triton_heuristics import (
autotune_hints_to_configs,
CachingAutotuner,
template,
triton_config,
)
from torch._inductor.test_case import run_tests, TestCase
@ -185,6 +188,29 @@ class TestTritonHeuristics(TestCase):
self.assertTrue(8 in seen_num_elements_per_warp)
@unittest.skipIf(not HAS_WARP_SPEC, "FBCODE Triton is required for this test")
def test_template_function_ws(self):
triton_meta = {"device": MagicMock()}
num_stages = 2
num_warps = 4
num_consumer_groups = 3
num_buffers_warp_spec = 5
with patch(
"torch._inductor.runtime.triton_heuristics.cached_autotune"
) as mock_cached_autotune:
template(
num_stages=num_stages,
num_warps=num_warps,
triton_meta=triton_meta,
num_consumer_groups=num_consumer_groups,
num_buffers_warp_spec=num_buffers_warp_spec,
)
mock_cached_autotune.assert_called_once()
configs = mock_cached_autotune.call_args[0][1]
self.assertEqual(configs[0].num_consumer_groups, num_consumer_groups)
self.assertEqual(configs[0].num_buffers_warp_spec, num_buffers_warp_spec)
class TestArgumentCloneAndRestore(TestCase):
# Our tensor is large enough. If a unexpected copy happens, the

View File

@ -1473,7 +1473,13 @@ class TritonHOPifier:
new_var = type(variable)(new_kernel, None, variable.grid)
return self.call_triton_kernel(new_var, args, kwargs, tx)
SPECIAL_CONFIG_NAMES = {"num_warps", "num_stages", "num_ctas"}
SPECIAL_CONFIG_NAMES = {
"num_warps",
"num_stages",
"num_ctas",
"num_consumer_groups",
"num_buffers_warp_spec",
}
# move special config names to configs out of kwargs
special_kwargs = {}

View File

@ -578,6 +578,8 @@ class TritonBenchmarkRequest(BenchmarkRequest):
module_cache_key: str,
num_stages: int,
num_warps: int,
num_consumer_groups: int = 0,
num_buffers_warp_spec: int = 0,
matrix_instr_nonkdim: int = 0, # only used for hip to choose the shape of mfma instruction.
waves_per_eu: int = 0, # only used for hip to schedule waves per execution unit
kpack: int = 0, # ROCm specific gemm paramete
@ -588,6 +590,8 @@ class TritonBenchmarkRequest(BenchmarkRequest):
self.module_cache_key = module_cache_key
self.num_stages = num_stages
self.num_warps = num_warps
self.num_consumer_groups = num_consumer_groups
self.num_buffers_warp_spec = num_buffers_warp_spec
self.matrix_instr_nonkdim = matrix_instr_nonkdim
self.waves_per_eu = waves_per_eu
self.kpack = kpack

View File

@ -20,7 +20,7 @@ from ..remote_cache import (
RemoteCacheBackend,
RemoteCacheJsonSerde,
)
from .triton_compat import Config
from .triton_compat import Config, HAS_WARP_SPEC
if TYPE_CHECKING:
@ -207,6 +207,15 @@ class AutotuneCache:
"found_by_coordesc": found_by_coordesc,
"time_taken_ms": time_taken_ns // 1000000, # Convert from NS to MS
}
if HAS_WARP_SPEC:
data.update(
{
"num_consumer_groups": getattr(config, "num_consumer_groups", 0),
"num_buffers_warp_spec": getattr(
config, "num_buffers_warp_spec", 0
),
}
)
if local_cache := self.local_cache:
cache, key = local_cache
@ -464,7 +473,25 @@ def _load_cached_autotuning(
):
num_warps = best_config.pop("num_warps")
num_stages = best_config.pop("num_stages")
triton_config = Config(best_config, num_warps=num_warps, num_stages=num_stages)
# Extract common arguments
config_args = {
"num_warps": num_warps,
"num_stages": num_stages,
}
if HAS_WARP_SPEC:
config_args.update(
{
"num_consumer_groups": best_config.pop("num_consumer_groups", 0),
"num_buffers_warp_spec": best_config.pop(
"num_buffers_warp_spec", 0
),
}
)
# Create the triton_config with the appropriate arguments
triton_config = Config(best_config, **config_args)
triton_config.found_by_coordesc = True
return triton_config

View File

@ -68,6 +68,7 @@ if triton is not None:
def _log2(x: Any) -> Any:
raise NotImplementedError
HAS_WARP_SPEC = hasattr(tl, "async_task")
else:
def _raise_error(*args: Any, **kwargs: Any) -> Any:
@ -101,6 +102,8 @@ else:
tensor = Any
dtype = Any
HAS_WARP_SPEC = False
def cc_warp_size(cc: Union[str, int]) -> int:
if torch.version.hip:

View File

@ -71,6 +71,7 @@ from .triton_compat import (
CompiledKernel,
Config,
GPUTarget,
HAS_WARP_SPEC,
KernelInterface,
OutOfResources,
PTXASError,
@ -169,6 +170,13 @@ def _dump_launch_params(args, kwargs, launcher, kernel_name, grid):
call_kwargs.update(launcher.config.kwargs)
call_kwargs["num_warps"] = launcher.config.num_warps
call_kwargs["num_stages"] = launcher.config.num_stages
if HAS_WARP_SPEC:
call_kwargs["num_consumer_groups"] = getattr(
launcher.config, "num_consumer_groups", 0
)
call_kwargs["num_buffers_warp_spec"] = getattr(
launcher.config, "num_buffers_warp_spec", 0
)
args_str = [*call_args]
args_str.extend(f"{k}={v}" for k, v in call_kwargs.items())
args_str = ", ".join(args_str)
@ -509,6 +517,11 @@ class CachingAutotuner(KernelInterface):
compile_meta["constants"][arg_name] = getattr(cfg, arg_name)
compile_meta["num_warps"] = cfg.num_warps
compile_meta["num_stages"] = cfg.num_stages
if HAS_WARP_SPEC:
compile_meta["num_consumer_groups"] = getattr(cfg, "num_consumer_groups", 0)
compile_meta["num_buffers_warp_spec"] = getattr(
cfg, "num_buffers_warp_spec", 0
)
compile_meta["debug"] = self.inductor_meta.get(
"assert_indirect_indexing", True
) and not self.inductor_meta.get("is_hip", False)
@ -555,6 +568,15 @@ class CachingAutotuner(KernelInterface):
"debug": compile_meta["debug"],
"sanitize_overflow": False, # turn off additional asserts added for overflow checks
}
if HAS_WARP_SPEC:
options.update(
{
"num_consumer_groups": compile_meta.get("num_consumer_groups", 0),
"num_buffers_warp_spec": compile_meta.get(
"num_buffers_warp_spec", 0
),
}
)
if self.device_props.type == "hip":
if "waves_per_eu" in compile_meta:
options["waves_per_eu"] = compile_meta["waves_per_eu"]
@ -2326,13 +2348,35 @@ def split_scan(
)
def template(num_stages, num_warps, triton_meta, filename=None, inductor_meta=None):
def template(
num_stages,
num_warps,
triton_meta,
num_consumer_groups=0,
num_buffers_warp_spec=0,
filename=None,
inductor_meta=None,
):
"""
Compile a triton template
"""
# Prepare the base configuration
config_args = {
"num_stages": num_stages,
"num_warps": num_warps,
}
# Conditionally add arguments based on HAS_WARP_SPEC
if HAS_WARP_SPEC:
config_args.update(
{
"num_consumer_groups": num_consumer_groups,
"num_buffers_warp_spec": num_buffers_warp_spec,
}
)
return cached_autotune(
None,
[triton.Config({}, num_stages=num_stages, num_warps=num_warps)],
[triton.Config({}, **config_args)],
triton_meta=triton_meta,
inductor_meta=inductor_meta,
heuristic_type=HeuristicType.TEMPLATE,
@ -2343,7 +2387,14 @@ def template(num_stages, num_warps, triton_meta, filename=None, inductor_meta=No
def _pop_config_kwargs(config: dict[str, Any]) -> dict[str, Any]:
"""Extract triton.Config options that should become kwargs"""
popped = {}
for key in ("num_warps", "num_stages", "num_ctas", "maxnreg"):
for key in (
"num_warps",
"num_stages",
"num_ctas",
"maxnreg",
"num_consumer_groups",
"num_buffers_warp_spec",
):
val = config.pop(key, None)
if val is not None:
popped[key] = val
@ -2351,11 +2402,19 @@ def _pop_config_kwargs(config: dict[str, Any]) -> dict[str, Any]:
def config_to_dict(config: Config) -> dict[str, Any]:
return {
config_dict = {
**config.kwargs,
"num_warps": config.num_warps,
"num_stages": config.num_stages,
}
if HAS_WARP_SPEC:
config_dict.update(
{
"num_consumer_groups": getattr(config, "num_consumer_groups", 0),
"num_buffers_warp_spec": getattr(config, "num_buffers_warp_spec", 0),
}
)
return config_dict
def config_from_dict(config: dict[str, Any]) -> Config:

View File

@ -61,6 +61,7 @@ from .ir import ChoiceCaller, PrimitiveInfoType
from .ops_handler import StoreMode
from .runtime.benchmarking import benchmarker
from .runtime.hints import DeviceProperties
from .runtime.triton_compat import HAS_WARP_SPEC
from .runtime.triton_heuristics import FixedGrid
from .utils import (
ceildiv,
@ -86,6 +87,7 @@ VERIFY: dict[str, Any] = {}
PRINT_AUTOTUNE = True
DEBUG = False
if TYPE_CHECKING:
import concurrent
@ -293,6 +295,8 @@ class TritonTemplateKernel(TritonKernel):
grid_fn,
meta,
call_sizes,
num_consumer_groups=0,
num_buffers_warp_spec=0,
use_jit=False,
prefix_args=0,
suffix_args=0,
@ -316,6 +320,8 @@ class TritonTemplateKernel(TritonKernel):
self.use_jit = use_jit
self.num_stages = num_stages
self.num_warps = num_warps
self.num_consumer_groups = num_consumer_groups
self.num_buffers_warp_spec = num_buffers_warp_spec
self.grid_fn = grid_fn
self.meta = meta
self.call_sizes = call_sizes
@ -457,12 +463,23 @@ class TritonTemplateKernel(TritonKernel):
if config.profile_bandwidth or config.benchmark_kernel:
num_gb = self.estimate_kernel_num_bytes() / 1e9
inductor_meta["kernel_num_gb"] = num_gb
template_args = f"""
num_stages={self.num_stages},
num_warps={self.num_warps},
triton_meta={triton_meta!r},
inductor_meta={inductor_meta!r},
"""
if HAS_WARP_SPEC:
template_args += f"""
num_consumer_groups={self.num_consumer_groups},
num_buffers_warp_spec={self.num_buffers_warp_spec},
"""
return f"""
@triton_heuristics.template(
num_stages={self.num_stages},
num_warps={self.num_warps},
triton_meta={triton_meta!r},
inductor_meta={inductor_meta!r},
{template_args}
)
@triton.jit
"""
@ -1070,6 +1087,8 @@ class TritonTemplate(KernelTemplate):
layout,
num_stages,
num_warps,
num_consumer_groups=0,
num_buffers_warp_spec=0,
prefix_args=0,
suffix_args=0,
epilogue_fn=identity,
@ -1135,6 +1154,13 @@ class TritonTemplate(KernelTemplate):
"epilogue_fn": epilogue_fn,
"subgraphs": subgraphs,
}
if HAS_WARP_SPEC:
kernel_options.update(
{
"num_consumer_groups": num_consumer_groups,
"num_buffers_warp_spec": num_buffers_warp_spec,
}
)
with (
patch.object(V.graph, "get_dtype", self._fake_get_dtype(fake_out)),
@ -1156,19 +1182,25 @@ class TritonTemplate(KernelTemplate):
return None
if self.debug:
print("Generated Code:\n", code)
extra = (
"-".join(
extra_parts = [
f"{kwarg}={repr(kwargs[kwarg])}" for kwarg in sorted(kwargs.keys())
]
extra_parts.extend(
[
f"num_stages={num_stages}",
f"num_warps={num_warps}",
]
)
if HAS_WARP_SPEC:
extra_parts.extend(
[
*[
f"{kwarg}={repr(kwargs[kwarg])}"
for kwarg in sorted(kwargs.keys())
],
f"num_stages={num_stages}",
f"num_warps={num_warps}",
f"num_consumer_groups={num_consumer_groups}",
f"num_buffers_warp_spec={num_buffers_warp_spec}",
]
)
+ "-"
)
extra = "-".join(extra_parts) + "-"
mod = PyCodeCache.load(code, extra)
input_call_args = tuple(kernel.args.input_buffers.keys())
@ -1224,6 +1256,8 @@ class TritonTemplate(KernelTemplate):
extra_args=[*extra_args, *grid],
num_stages=num_stages,
num_warps=num_warps,
num_consumer_groups=num_consumer_groups,
num_buffers_warp_spec=num_buffers_warp_spec,
matrix_instr_nonkdim=kwargs.get("matrix_instr_nonkdim", 0),
waves_per_eu=kwargs.get("waves_per_eu", 0),
kpack=kwargs.get("kpack", 2),