[pytorch][triton] Warp specialization support in TritonTemplate for torchinductor (#148503) (#150122)
Summary:
Currently only `num_warps` and `num_stages` are supported as kernel options for Inductor auto-tuning with `TritonTemplate`.
To enable warp specialization, the kernel options should also allow specifying `num_consumer_groups` and `num_buffers_warp_spec`.
NOTE: These changes are currently gated for FBCODE via `HAS_WARP_SPEC`, which is only available on triton/release-3.3.x.
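Concretely, after this change the Inductor template heuristic accepts the two new options and threads them into the `triton.Config` it builds. A minimal sketch, mirroring the `test_template_function_ws` unit test added in this PR (the specific values are illustrative, not tuned defaults):
```
# Sketch only: mirrors the added test_template_function_ws test. Requires a
# Triton build with warp specialization (HAS_WARP_SPEC), e.g. triton/release-3.3.x.
from unittest.mock import MagicMock, patch

from torch._inductor.runtime.triton_compat import HAS_WARP_SPEC
from torch._inductor.runtime.triton_heuristics import template

assert HAS_WARP_SPEC, "warp specialization not available in this Triton build"

with patch("torch._inductor.runtime.triton_heuristics.cached_autotune") as mock_autotune:
    template(
        num_stages=2,
        num_warps=4,
        triton_meta={"device": MagicMock()},
        num_consumer_groups=3,  # illustrative values
        num_buffers_warp_spec=5,
    )

# The config handed to cached_autotune carries the warp-spec options.
configs = mock_autotune.call_args[0][1]
print(configs[0].num_consumer_groups, configs[0].num_buffers_warp_spec)  # 3 5
```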
Test Plan:
## Unit test
Added `test_triton_template_warp_specialization` to verify that the generated kernel contains configs for `num_consumer_groups` and `num_buffers_warp_spec`.
## Functional Testing
Specific to `flex_attention`.
```
import torch
from torch.nn.attention.flex_attention import flex_attention
from triton.testing import do_bench
make_tensor = lambda: torch.rand(8, 16, 8192, 128, device="cuda", dtype=torch.bfloat16)
q, k, v = make_tensor(), make_tensor(), make_tensor()
flex_compiled = torch.compile(flex_attention, fullgraph=True)
print(do_bench(lambda: flex_compiled(q, k, v, kernel_options={"num_warps": 4})))
```
Triton `do_bench` results (ms):
- default compile: 15.176783561706543
- with warp specialization: 9.452800750732422 (invocation sketched below)
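The snippet above shows only the default-compile measurement; a warp-specialized run would presumably be benchmarked by threading the new options through `kernel_options`, roughly as below. The group/buffer counts here are illustrative and not necessarily the configuration behind the 9.45 ms figure.
```
# Hypothetical warp-specialized variant of the benchmark above; reuses q, k, v,
# flex_compiled, and do_bench from the previous snippet. The num_consumer_groups
# and num_buffers_warp_spec values are illustrative, not tuned settings.
warp_spec_options = {
    "num_warps": 4,
    "num_consumer_groups": 2,
    "num_buffers_warp_spec": 3,
}
print(do_bench(lambda: flex_compiled(q, k, v, kernel_options=warp_spec_options)))
```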
## Extra notes
- Generated Triton kernel (via `TORCH_LOGS=output_code`): P1740612877
- TTGIR for the fused kernel: P1740614685
Differential Revision: D71982587
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150122
Approved by: https://github.com/eellison, https://github.com/zou3519, https://github.com/jansel
Parent: 03313c6619
Commit: 0861af2596
```
@@ -16,6 +16,7 @@ from unittest.mock import patch
import torch
from torch._dynamo.testing import CompileCounterWithBackend, normalize_gm
from torch._inductor import metrics
from torch._inductor.runtime.triton_compat import HAS_WARP_SPEC
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import run_and_get_code
from torch.nn.attention.experimental._paged_attention import PagedAttention

@@ -3708,6 +3709,48 @@ class GraphModule(torch.nn.Module):
        q, k, v = [torch.randn(2, 2, 128, 16, device="cuda") for _ in range(3)]
        compiled_fa = torch.compile(flex_attention)

    @unittest.skipIf(
        not has_triton() or not HAS_WARP_SPEC,
        reason="FBCODE Triton is required for this test",
    )
    def test_triton_template_warp_specialization(self):
        def make_tensor():
            return torch.rand(4, 16, 4096, 64, device="cuda", dtype=torch.bfloat16)

        q, k, v = make_tensor(), make_tensor(), make_tensor()
        flex_compiled = torch.compile(flex_attention, fullgraph=True)

        positional_args = (q, k, v)
        keyword_args = {
            "kernel_options": {
                "num_warps": 4,
                "num_consumer_groups": 0,
                "num_buffers_warp_spec": 0,
            }
        }

        # Check if kernel code contains warp specialization parameters
        _, kernel_code = run_and_get_code(
            flex_compiled,
            *positional_args,
            **keyword_args,
        )
        assert kernel_code is not None, "Failed to retrieve compiled kernel code"
        assert (
            "num_consumer_groups" in kernel_code[0]
        ), "num_consumer_groups missing in kernel definition"
        assert (
            "num_buffers_warp_spec" in kernel_code[0]
        ), "num_buffers_warp_spec missing in kernel definition"

        # Validate correctness
        C1 = flex_compiled(q, k, v)
        C2 = flex_attention(q, k, v)

        assert torch.allclose(
            C1, C2, atol=1e-2, rtol=1e-2
        ), "Warp specialized kernel result differs from reference"


class TestBlockMask(InductorTestCase):
    @supported_platform

@@ -357,6 +357,8 @@ class TestSelectAlgorithm(TestCase):
            extra_args=None,
            num_stages=None,
            num_warps=None,
            num_consumer_groups=None,
            num_buffers_warp_spec=None,
            input_tensor_meta=None,
            output_tensor_meta=None,
        )

@@ -2,9 +2,11 @@

import sys
import unittest
from unittest.mock import MagicMock, patch

import torch
from torch._dynamo.testing import rand_strided
from torch._inductor.runtime.triton_compat import HAS_WARP_SPEC
from torch._inductor.utils import clone_preserve_strides
from torch.testing._internal.common_utils import IS_LINUX, skipIfXpu
from torch.testing._internal.inductor_utils import (

@@ -34,6 +36,7 @@ from torch._inductor.runtime.triton_helpers import math as tl_math
from torch._inductor.runtime.triton_heuristics import (
    autotune_hints_to_configs,
    CachingAutotuner,
    template,
    triton_config,
)
from torch._inductor.test_case import run_tests, TestCase

@@ -185,6 +188,29 @@ class TestTritonHeuristics(TestCase):

        self.assertTrue(8 in seen_num_elements_per_warp)

    @unittest.skipIf(not HAS_WARP_SPEC, "FBCODE Triton is required for this test")
    def test_template_function_ws(self):
        triton_meta = {"device": MagicMock()}
        num_stages = 2
        num_warps = 4
        num_consumer_groups = 3
        num_buffers_warp_spec = 5

        with patch(
            "torch._inductor.runtime.triton_heuristics.cached_autotune"
        ) as mock_cached_autotune:
            template(
                num_stages=num_stages,
                num_warps=num_warps,
                triton_meta=triton_meta,
                num_consumer_groups=num_consumer_groups,
                num_buffers_warp_spec=num_buffers_warp_spec,
            )
            mock_cached_autotune.assert_called_once()
            configs = mock_cached_autotune.call_args[0][1]
            self.assertEqual(configs[0].num_consumer_groups, num_consumer_groups)
            self.assertEqual(configs[0].num_buffers_warp_spec, num_buffers_warp_spec)


class TestArgumentCloneAndRestore(TestCase):
    # Our tensor is large enough. If a unexpected copy happens, the

@@ -1473,7 +1473,13 @@ class TritonHOPifier:
            new_var = type(variable)(new_kernel, None, variable.grid)
            return self.call_triton_kernel(new_var, args, kwargs, tx)

        SPECIAL_CONFIG_NAMES = {"num_warps", "num_stages", "num_ctas"}
        SPECIAL_CONFIG_NAMES = {
            "num_warps",
            "num_stages",
            "num_ctas",
            "num_consumer_groups",
            "num_buffers_warp_spec",
        }

        # move special config names to configs out of kwargs
        special_kwargs = {}

@@ -578,6 +578,8 @@ class TritonBenchmarkRequest(BenchmarkRequest):
        module_cache_key: str,
        num_stages: int,
        num_warps: int,
        num_consumer_groups: int = 0,
        num_buffers_warp_spec: int = 0,
        matrix_instr_nonkdim: int = 0,  # only used for hip to choose the shape of mfma instruction.
        waves_per_eu: int = 0,  # only used for hip to schedule waves per execution unit
        kpack: int = 0,  # ROCm specific gemm paramete

@@ -588,6 +590,8 @@ class TritonBenchmarkRequest(BenchmarkRequest):
        self.module_cache_key = module_cache_key
        self.num_stages = num_stages
        self.num_warps = num_warps
        self.num_consumer_groups = num_consumer_groups
        self.num_buffers_warp_spec = num_buffers_warp_spec
        self.matrix_instr_nonkdim = matrix_instr_nonkdim
        self.waves_per_eu = waves_per_eu
        self.kpack = kpack

@@ -20,7 +20,7 @@ from ..remote_cache import (
    RemoteCacheBackend,
    RemoteCacheJsonSerde,
)
from .triton_compat import Config
from .triton_compat import Config, HAS_WARP_SPEC


if TYPE_CHECKING:

@@ -207,6 +207,15 @@ class AutotuneCache:
            "found_by_coordesc": found_by_coordesc,
            "time_taken_ms": time_taken_ns // 1000000,  # Convert from NS to MS
        }
        if HAS_WARP_SPEC:
            data.update(
                {
                    "num_consumer_groups": getattr(config, "num_consumer_groups", 0),
                    "num_buffers_warp_spec": getattr(
                        config, "num_buffers_warp_spec", 0
                    ),
                }
            )

        if local_cache := self.local_cache:
            cache, key = local_cache

@@ -464,7 +473,25 @@ def _load_cached_autotuning(
    ):
        num_warps = best_config.pop("num_warps")
        num_stages = best_config.pop("num_stages")
        triton_config = Config(best_config, num_warps=num_warps, num_stages=num_stages)

        # Extract common arguments
        config_args = {
            "num_warps": num_warps,
            "num_stages": num_stages,
        }

        if HAS_WARP_SPEC:
            config_args.update(
                {
                    "num_consumer_groups": best_config.pop("num_consumer_groups", 0),
                    "num_buffers_warp_spec": best_config.pop(
                        "num_buffers_warp_spec", 0
                    ),
                }
            )

        # Create the triton_config with the appropriate arguments
        triton_config = Config(best_config, **config_args)
        triton_config.found_by_coordesc = True
        return triton_config

@@ -68,6 +68,7 @@ if triton is not None:
        def _log2(x: Any) -> Any:
            raise NotImplementedError

    HAS_WARP_SPEC = hasattr(tl, "async_task")
else:

    def _raise_error(*args: Any, **kwargs: Any) -> Any:

@@ -101,6 +102,8 @@ else:
    tensor = Any
    dtype = Any

    HAS_WARP_SPEC = False


def cc_warp_size(cc: Union[str, int]) -> int:
    if torch.version.hip:

@@ -71,6 +71,7 @@ from .triton_compat import (
    CompiledKernel,
    Config,
    GPUTarget,
    HAS_WARP_SPEC,
    KernelInterface,
    OutOfResources,
    PTXASError,

@@ -169,6 +170,13 @@ def _dump_launch_params(args, kwargs, launcher, kernel_name, grid):
    call_kwargs.update(launcher.config.kwargs)
    call_kwargs["num_warps"] = launcher.config.num_warps
    call_kwargs["num_stages"] = launcher.config.num_stages
    if HAS_WARP_SPEC:
        call_kwargs["num_consumer_groups"] = getattr(
            launcher.config, "num_consumer_groups", 0
        )
        call_kwargs["num_buffers_warp_spec"] = getattr(
            launcher.config, "num_buffers_warp_spec", 0
        )
    args_str = [*call_args]
    args_str.extend(f"{k}={v}" for k, v in call_kwargs.items())
    args_str = ", ".join(args_str)

@@ -509,6 +517,11 @@ class CachingAutotuner(KernelInterface):
            compile_meta["constants"][arg_name] = getattr(cfg, arg_name)
        compile_meta["num_warps"] = cfg.num_warps
        compile_meta["num_stages"] = cfg.num_stages
        if HAS_WARP_SPEC:
            compile_meta["num_consumer_groups"] = getattr(cfg, "num_consumer_groups", 0)
            compile_meta["num_buffers_warp_spec"] = getattr(
                cfg, "num_buffers_warp_spec", 0
            )
        compile_meta["debug"] = self.inductor_meta.get(
            "assert_indirect_indexing", True
        ) and not self.inductor_meta.get("is_hip", False)

@@ -555,6 +568,15 @@ class CachingAutotuner(KernelInterface):
            "debug": compile_meta["debug"],
            "sanitize_overflow": False,  # turn off additional asserts added for overflow checks
        }
        if HAS_WARP_SPEC:
            options.update(
                {
                    "num_consumer_groups": compile_meta.get("num_consumer_groups", 0),
                    "num_buffers_warp_spec": compile_meta.get(
                        "num_buffers_warp_spec", 0
                    ),
                }
            )
        if self.device_props.type == "hip":
            if "waves_per_eu" in compile_meta:
                options["waves_per_eu"] = compile_meta["waves_per_eu"]

@@ -2326,13 +2348,35 @@ def split_scan(
    )


def template(num_stages, num_warps, triton_meta, filename=None, inductor_meta=None):
def template(
    num_stages,
    num_warps,
    triton_meta,
    num_consumer_groups=0,
    num_buffers_warp_spec=0,
    filename=None,
    inductor_meta=None,
):
    """
    Compile a triton template
    """
    # Prepare the base configuration
    config_args = {
        "num_stages": num_stages,
        "num_warps": num_warps,
    }

    # Conditionally add arguments based on HAS_WARP_SPEC
    if HAS_WARP_SPEC:
        config_args.update(
            {
                "num_consumer_groups": num_consumer_groups,
                "num_buffers_warp_spec": num_buffers_warp_spec,
            }
        )
    return cached_autotune(
        None,
        [triton.Config({}, num_stages=num_stages, num_warps=num_warps)],
        [triton.Config({}, **config_args)],
        triton_meta=triton_meta,
        inductor_meta=inductor_meta,
        heuristic_type=HeuristicType.TEMPLATE,

@@ -2343,7 +2387,14 @@ def template(num_stages, num_warps, triton_meta, filename=None, inductor_meta=No
def _pop_config_kwargs(config: dict[str, Any]) -> dict[str, Any]:
    """Extract triton.Config options that should become kwargs"""
    popped = {}
    for key in ("num_warps", "num_stages", "num_ctas", "maxnreg"):
    for key in (
        "num_warps",
        "num_stages",
        "num_ctas",
        "maxnreg",
        "num_consumer_groups",
        "num_buffers_warp_spec",
    ):
        val = config.pop(key, None)
        if val is not None:
            popped[key] = val

@@ -2351,11 +2402,19 @@ def _pop_config_kwargs(config: dict[str, Any]) -> dict[str, Any]:


def config_to_dict(config: Config) -> dict[str, Any]:
    return {
    config_dict = {
        **config.kwargs,
        "num_warps": config.num_warps,
        "num_stages": config.num_stages,
    }
    if HAS_WARP_SPEC:
        config_dict.update(
            {
                "num_consumer_groups": getattr(config, "num_consumer_groups", 0),
                "num_buffers_warp_spec": getattr(config, "num_buffers_warp_spec", 0),
            }
        )
    return config_dict


def config_from_dict(config: dict[str, Any]) -> Config:

@@ -61,6 +61,7 @@ from .ir import ChoiceCaller, PrimitiveInfoType
from .ops_handler import StoreMode
from .runtime.benchmarking import benchmarker
from .runtime.hints import DeviceProperties
from .runtime.triton_compat import HAS_WARP_SPEC
from .runtime.triton_heuristics import FixedGrid
from .utils import (
    ceildiv,

@@ -86,6 +87,7 @@ VERIFY: dict[str, Any] = {}
PRINT_AUTOTUNE = True
DEBUG = False


if TYPE_CHECKING:
    import concurrent

@@ -293,6 +295,8 @@ class TritonTemplateKernel(TritonKernel):
        grid_fn,
        meta,
        call_sizes,
        num_consumer_groups=0,
        num_buffers_warp_spec=0,
        use_jit=False,
        prefix_args=0,
        suffix_args=0,

@@ -316,6 +320,8 @@ class TritonTemplateKernel(TritonKernel):
        self.use_jit = use_jit
        self.num_stages = num_stages
        self.num_warps = num_warps
        self.num_consumer_groups = num_consumer_groups
        self.num_buffers_warp_spec = num_buffers_warp_spec
        self.grid_fn = grid_fn
        self.meta = meta
        self.call_sizes = call_sizes

@@ -457,12 +463,23 @@ class TritonTemplateKernel(TritonKernel):
        if config.profile_bandwidth or config.benchmark_kernel:
            num_gb = self.estimate_kernel_num_bytes() / 1e9
            inductor_meta["kernel_num_gb"] = num_gb

        template_args = f"""
            num_stages={self.num_stages},
            num_warps={self.num_warps},
            triton_meta={triton_meta!r},
            inductor_meta={inductor_meta!r},
        """

        if HAS_WARP_SPEC:
            template_args += f"""
            num_consumer_groups={self.num_consumer_groups},
            num_buffers_warp_spec={self.num_buffers_warp_spec},
        """

        return f"""
            @triton_heuristics.template(
                num_stages={self.num_stages},
                num_warps={self.num_warps},
                triton_meta={triton_meta!r},
                inductor_meta={inductor_meta!r},
                {template_args}
            )
            @triton.jit
        """

@@ -1070,6 +1087,8 @@ class TritonTemplate(KernelTemplate):
        layout,
        num_stages,
        num_warps,
        num_consumer_groups=0,
        num_buffers_warp_spec=0,
        prefix_args=0,
        suffix_args=0,
        epilogue_fn=identity,

@@ -1135,6 +1154,13 @@ class TritonTemplate(KernelTemplate):
            "epilogue_fn": epilogue_fn,
            "subgraphs": subgraphs,
        }
        if HAS_WARP_SPEC:
            kernel_options.update(
                {
                    "num_consumer_groups": num_consumer_groups,
                    "num_buffers_warp_spec": num_buffers_warp_spec,
                }
            )

        with (
            patch.object(V.graph, "get_dtype", self._fake_get_dtype(fake_out)),

@@ -1156,19 +1182,25 @@ class TritonTemplate(KernelTemplate):
            return None
        if self.debug:
            print("Generated Code:\n", code)
        extra = (
            "-".join(
        extra_parts = [
            f"{kwarg}={repr(kwargs[kwarg])}" for kwarg in sorted(kwargs.keys())
        ]

        extra_parts.extend(
            [
                f"num_stages={num_stages}",
                f"num_warps={num_warps}",
            ]
        )
        if HAS_WARP_SPEC:
            extra_parts.extend(
                [
                    *[
                        f"{kwarg}={repr(kwargs[kwarg])}"
                        for kwarg in sorted(kwargs.keys())
                    ],
                    f"num_stages={num_stages}",
                    f"num_warps={num_warps}",
                    f"num_consumer_groups={num_consumer_groups}",
                    f"num_buffers_warp_spec={num_buffers_warp_spec}",
                ]
            )
            + "-"
        )

        extra = "-".join(extra_parts) + "-"
        mod = PyCodeCache.load(code, extra)

        input_call_args = tuple(kernel.args.input_buffers.keys())

@@ -1224,6 +1256,8 @@ class TritonTemplate(KernelTemplate):
            extra_args=[*extra_args, *grid],
            num_stages=num_stages,
            num_warps=num_warps,
            num_consumer_groups=num_consumer_groups,
            num_buffers_warp_spec=num_buffers_warp_spec,
            matrix_instr_nonkdim=kwargs.get("matrix_instr_nonkdim", 0),
            waves_per_eu=kwargs.get("waves_per_eu", 0),
            kpack=kwargs.get("kpack", 2),
```