Revert "[inductor] consolidate common GEMM triton param retrieval (#159383)"

This reverts commit e7cc42df58.

Reverted https://github.com/pytorch/pytorch/pull/159383 on behalf of https://github.com/jataylo due to sorry but rocm CI is broken due to this PR ([comment](https://github.com/pytorch/pytorch/pull/159383#issuecomment-3145604831))
PyTorch MergeBot 2025-08-01 19:49:21 +00:00
parent c687446374
commit acad808545
10 changed files with 469 additions and 1316 deletions

View File

@ -35,10 +35,7 @@ from torch._inductor.select_algorithm import (
TritonTemplate,
TritonTemplateCaller,
)
from torch._inductor.template_heuristics import (
CUDAMMTemplateConfigHeuristic,
GemmConfig,
)
from torch._inductor.template_heuristics import CUDAConfigHeuristic, GemmConfig
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
@ -1176,7 +1173,7 @@ class TestMaxAutotune(TestCase):
# Force only decomposeK choice
with (
mock.patch(
"torch._inductor.kernel.mm.V.choices.get_mm_configs"
"torch._inductor.kernel.mm.V.choices.get_base_mm_configs"
) as base_mm_mock,
mock.patch(
"torch._inductor.kernel.mm.use_decompose_k_choice"
@ -1564,9 +1561,9 @@ class TestMaxAutotune(TestCase):
b = torch.randn(K, N, dtype=torch.float16, device="cuda", requires_grad=True)
with mock.patch(
"torch._inductor.template_registry.get_template_heuristic"
"torch._inductor.kernel.mm.V.choices.get_config_heuristics"
) as config_mock:
config_heuristics = CUDAMMTemplateConfigHeuristic()
config_heuristics = CUDAConfigHeuristic()
# Traditionally, this would be set of all possible configs
# We mock out the code path for the sake of the unit test

View File

@ -9,7 +9,6 @@ import torch
from . import config
from .codecache import write_text
from .kernel_inputs import KernelInputs # noqa: TC001
from .metrics import get_metric_table, is_metric_table_enabled
from .runtime.hints import DeviceProperties, ReductionHint
from .scheduler import BaseSchedulerNode, Scheduler, WhyNoFuse
@ -21,7 +20,6 @@ from .template_heuristics import (
ROCmConfigHeuristic,
XPUConfigHeuristic,
)
from .template_registry import get_template_heuristic
from .virtualized import V
@ -73,6 +71,58 @@ class InductorChoices:
else:
return BaseConfigHeuristic()
# GEMM configs
def get_base_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
if config.max_autotune_gemm_search_space != "EXHAUSTIVE":
return mm_heuristics.get_mm_configs()
else:
return mm_heuristics.get_exhaustive_mm_configs()
def get_extra_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_extra_mm_configs()
def get_int8_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_int8_mm_configs()
def get_mixed_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_mixed_mm_configs()
def get_persistent_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_persistent_mm_configs()
def get_scaled_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_scaled_mm_configs()
def get_scaled_persistent_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_scaled_persistent_mm_configs()
def get_mm_plus_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_mm_plus_mm_configs()
# Conv configs
def get_conv_configs(
self, device_type: Optional[str] = "cuda"
@ -81,7 +131,6 @@ class InductorChoices:
return conv_heuristics.get_conv_configs()
# Flex attention configs
# TODO(coconutruben): break out flexattention/decode configs into the new retrieval mechanism
def get_flex_attention_fwd_configs(
self, head_dim: int, dtype: torch.dtype, device_type: Optional[str] = "cuda"
) -> list[Any]:
@ -100,37 +149,6 @@ class InductorChoices:
flex_heuristics = self.get_config_heuristics(device_type)
return flex_heuristics.get_flex_decode_configs(head_dim, dtype)
def get_mm_configs(
self,
kernel_inputs: KernelInputs,
layout: Any,
template_name: str,
op_name: str,
) -> Generator[dict[str, Any], None, None]:
"""
Get generator of template parameters for MM templates using template-specific heuristics.
Args:
kernel_inputs: MMKernelInputs containing input tensor nodes and matrix indices
layout: Output layout
template_name: Template name (e.g., "bmm", "mm", "mm_persistent_tma")
op_name: Operation name (e.g., "bmm", "baddbmm", "addmm", "mm_plus_mm")
Yields:
Template parameter dictionaries ready for maybe_append_choice
"""
input_tensors = kernel_inputs.nodes()
if len(input_tensors) < 2:
raise ValueError(f"Need at least 2 input tensors, got {len(input_tensors)}")
# Extract device_type from kernel_inputs
device_type = kernel_inputs.device_type
assert device_type is not None, "get_mm_configs requires a valid device type"
# Get the appropriate template-specific heuristic
heuristic = get_template_heuristic(template_name, device_type, op_name)
yield from heuristic.get_template_configs(kernel_inputs, layout, op_name)
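
The unified get_mm_configs entry point above is removed by this revert; the per-operation getters restored earlier in this hunk take its place. A condensed sketch of how the lowerings in the kernel/*.py hunks below consume them (variable names such as mat1, m, n, k and choices come from those call sites; this is stitched together for illustration, not a verbatim excerpt of any single lowering):

device_type = ir.get_device_type(mat1)
mm_configs = V.choices.get_base_mm_configs(device_type)  # partial over BaseConfigHeuristic.preprocess_mm_configs
for config in mm_configs(
    m, n, k,
    **mm_config_kwargs(device_type, _is_large_block_for_cpu, mat1.get_dtype().itemsize),
):
    mm_template.maybe_append_choice(
        choices,
        input_nodes=(mat1, mat2),
        layout=layout,
        **mm_options(config, m, n, k, layout),  # triton.Config -> template kwargs
    )
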
def triton_kernel_kwargs(
self,
kernel_cls: type[TritonKernel],

View File

@ -6,7 +6,6 @@ from torch._dynamo.utils import counters
from torch._inductor.codegen.rocm.ck_universal_gemm_template import CKGemmTemplate
from .. import ir, lowering as L
from ..kernel_inputs import MMKernelInputs
from ..select_algorithm import (
autotune_select_algorithm,
ExternKernelChoice,
@ -27,6 +26,8 @@ from .mm_common import (
addmm_epilogue,
is_batch_stride_largest,
mm_args,
mm_config_kwargs,
mm_options,
)
@ -39,6 +40,13 @@ def bmm_grid(b, m, n, meta, *, cdiv):
return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), b, 1)
def _is_large_block_for_cpu(m, n, k):
# Thresholds are experimentally determined to reduce Triton CPU compile times
if m > 128 or n > 128 or k > 128:
return True
return m * n > 2**12
bmm_template = TritonTemplate(
name="bmm",
grid=bmm_grid,
@ -167,14 +175,9 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
meta_mat2 = V.graph.current_node.args[1]
mat2 = may_require_contiguous(mat2, meta_mat2)
# TODO(coconutruben): integrate into MMKernelInputs when all callsites use that
m, n, k, layout, mat1, mat2 = mm_args(
mat1, mat2, layout=layout, out_dtype=out_dtype
)
name = "bmm"
# Create MMKernelInputs for BMM at the top
kernel_inputs = MMKernelInputs([mat1, mat2])
# below is for getting an overview logging info of inductor mms
batch_size = mat1.get_size()[0] # Extract batch dimension
@ -192,27 +195,31 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
if out_dtype:
assert mat1.get_device().type == "cuda", "out_dtype is only supported for CUDA"
aten_func = aten_bmm_dtype.bind(
kernel_inputs.nodes(), layout, out_dtype=out_dtype
)
aten_func = aten_bmm_dtype.bind((mat1, mat2), layout, out_dtype=out_dtype)
else:
aten_func = aten_bmm.bind(kernel_inputs.nodes(), layout)
aten_func = aten_bmm.bind((mat1, mat2), layout)
# options to tune from
choices = [aten_func] if use_aten_gemm_kernels() else []
device_type = ir.get_device_type(mat1)
bmm_configs = V.choices.get_base_mm_configs(device_type)
dtype = mat1.get_dtype()
if use_triton_template(layout):
# TODO: add out_dtype support for Triton Template
assert out_dtype is None, "out_dtype is not supported for Triton"
for kwargs in V.choices.get_mm_configs(
kernel_inputs, layout, bmm_template.name, name
for config in bmm_configs(
m,
n,
k,
**mm_config_kwargs(device_type, _is_large_block_for_cpu, dtype.itemsize),
):
bmm_template.maybe_append_choice(
choices,
input_nodes=kernel_inputs.nodes(),
input_nodes=(mat1, mat2),
layout=layout,
**kwargs,
**mm_options(config, m, n, k, layout),
)
_, is_nonzero = _is_static_problem(layout)
batch_stride_largest = is_batch_stride_largest(mat1, mat2, layout)
@ -220,13 +227,11 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
batch_stride_largest
and is_nonzero
and use_cutlass_template(layout, m, n, k)
and _use_cutlass_for_op(name)
and _use_cutlass_for_op("bmm")
):
from ..codegen.cuda.gemm_template import CUTLASS3xGemmTemplate
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
choices, layout, kernel_inputs.nodes()
) # type: ignore[arg-type]
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2]) # type: ignore[arg-type]
if use_cpp_bmm_template(layout, mat1, mat2):
from ..codegen.cpp_bmm_template import CppBmmTemplate
@ -234,23 +239,19 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
CppBmmTemplate.add_choices(
choices,
layout,
kernel_inputs.nodes(),
[mat1, mat2],
)
if use_ck_gemm_template(layout, m, n, k):
CKGemmTemplate.add_ck_gemm_choices(choices, layout, kernel_inputs.nodes())
CKGemmTemplate.add_ck_gemm_choices(choices, layout, [mat1, mat2])
return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout)
return autotune_select_algorithm("bmm", choices, [mat1, mat2], layout)
@L.register_lowering(aten.baddbmm)
def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
# TODO(coconutruben): integrate into MMKernelInputs when all callsites use that
m, n, k, layout, mat1, mat2, inp = mm_args(mat1, mat2, inp, layout=layout)
# Create MMKernelInputs for BadDBMM at the top
kernel_inputs = MMKernelInputs([inp, mat1, mat2])
# below is for getting an overview logging info of inductor mms
batch_size = mat1.get_size()[0]
counters["aten_mm_info"][f"aten.baddbmm_{batch_size}_{m}_{n}_{k}"] += 1
@ -265,26 +266,29 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
inp.get_dtype(),
layout,
)
name = "baddbmm"
# options to tune from
choices = (
[aten_baddbmm.bind(kernel_inputs.nodes(), layout, alpha=alpha, beta=beta)]
[aten_baddbmm.bind((inp, mat1, mat2), layout, alpha=alpha, beta=beta)]
if use_aten_gemm_kernels()
else []
)
device_type = ir.get_device_type(mat1)
bmm_configs = V.choices.get_base_mm_configs(device_type)
if use_triton_template(layout):
for kwargs in V.choices.get_mm_configs(
kernel_inputs, layout, bmm_template.name, name
for config in bmm_configs(
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
):
bmm_template.maybe_append_choice(
choices,
input_nodes=kernel_inputs.nodes(),
input_nodes=(inp, mat1, mat2),
layout=layout,
**kwargs,
**mm_options(config, m, n, k, layout),
prefix_args=1,
epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
epilogue_fn_hash=str(["addmm_epilogue", layout.dtype, alpha, beta]),
)
return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout)
return autotune_select_algorithm("baddbmm", choices, [inp, mat1, mat2], layout)

View File

@ -29,6 +29,7 @@ from ..utils import (
use_triton_template,
)
from ..virtualized import V
from .mm_common import mm_config_kwargs
if TYPE_CHECKING:
@ -60,6 +61,13 @@ def conv3d_grid(n, c, d, h, w, meta, *, cdiv):
)
def _is_large_block_for_cpu(m, n, k):
# Thresholds are experimentally determined to reduce Triton CPU compile times
if m > 256 or n > 256 or k > 256:
return True
return m * n * k > 2**17
LOOP_BODY_2D = """
idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H
idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W
@ -595,6 +603,7 @@ def convolution(
sympy_product([x.get_size()[0], *x.get_size()[2:]]),
out_chan,
in_chan,
**mm_config_kwargs(device_type, _is_large_block_for_cpu),
):
if ndim == 2:
conv2d_template.maybe_append_choice(

View File

@ -19,13 +19,12 @@ from torch._inductor.virtualized import V
from torch.fx.experimental.proxy_tensor import make_fx
from torch.torch_version import TorchVersion
from .. import config as inductor_config
from .. import config as inductor_config, ir
from ..codegen.cuda.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate
from ..codegen.rocm.ck_tile_universal_gemm_template import CKTileGemmTemplate
from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate
from ..codegen.subgraph import SubgraphTemplate
from ..ir import FlexibleLayout, is_triton
from ..kernel_inputs import MMKernelInputs
from ..lowering import (
add_layout_constraint,
constrain_to_fx_strides,
@ -55,9 +54,13 @@ from .mm_common import (
_is_static_problem,
addmm_epilogue,
mm_args,
mm_config_kwargs,
mm_grid,
mm_options,
persistent_mm_grid,
persistent_mm_options,
scale_mm_epilogue,
scaled_mm_options,
)
@ -584,6 +587,11 @@ def _is_int8_mat(mat):
return mat.get_dtype() in (torch.int8, torch.uint8)
def _is_large_block_for_cpu(m, n, k):
# Thresholds are experimentally determined to reduce Triton CPU compile times
return m * n > 2**13
@functools.lru_cache
def using_b200() -> bool:
"""Returns true if the device is a NVIDIA B200, otherwise returns false."""
@ -653,14 +661,10 @@ def tuned_mm(mat1, mat2, *, layout=None):
"""
Lowering for autotuning aten.mm with different backends (Aten, Triton, CUTLASS, etc.)
"""
# TODO(coconutruben): integrate into MMKernelInputs when all callsites use that
m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
static_shape, is_nonzero = _is_static_problem(layout)
device_type = ir.get_device_type(mat1)
name = "mm"
# Create MMKernelInputs for standard MM at the top
kernel_inputs = MMKernelInputs([mat1, mat2])
# below is for getting an overview logging info of inductor mms
counters["aten_mm_info"][f"aten.mm_{m}_{n}_{k}"] += 1
log.info(
@ -681,38 +685,48 @@ def tuned_mm(mat1, mat2, *, layout=None):
# options to tune from
choices = (
[aten_mm.bind(kernel_inputs.nodes(), aten_layout)]
if use_aten_gemm_kernels()
else []
[aten_mm.bind((mat1, mat2), aten_layout)] if use_aten_gemm_kernels() else []
)
static_shape, is_nonzero = _is_static_problem(layout)
mm_configs = V.choices.get_base_mm_configs(device_type)
persistent_mm_configs = V.choices.get_persistent_mm_configs(device_type)
extra_mm_configs = V.choices.get_extra_mm_configs(device_type)
dtype = mat1.get_dtype()
if is_nonzero and use_triton_template(layout):
# Get template params using the new unified function
for kwargs in V.choices.get_mm_configs(
kernel_inputs, layout, mm_template.name, "mm"
for config in mm_configs(
m,
n,
k,
**mm_config_kwargs(device_type, _is_large_block_for_cpu, dtype.itemsize),
):
mm_template.maybe_append_choice(
choices,
input_nodes=kernel_inputs.nodes(),
input_nodes=(mat1, mat2),
layout=layout,
**kwargs,
**mm_options(config, m, n, k, layout),
)
if use_triton_tma_template(mat1, mat2):
# Get TMA template params using the new unified function
for kwargs in V.choices.get_mm_configs(
kernel_inputs, layout, persistent_tma_mm_template.name, "mm"
for config in persistent_mm_configs(
m,
n,
k,
**mm_config_kwargs(
device_type, _is_large_block_for_cpu, dtype.itemsize
),
):
persistent_tma_mm_template.maybe_append_choice(
choices,
input_nodes=kernel_inputs.nodes(),
input_nodes=(mat1, mat2),
layout=layout,
workspace_arg=get_tma_workspace_arg(
num_tma_descriptors=2,
device=mat1.get_device(),
),
**kwargs,
**mm_options(config, m, n, k, layout),
**persistent_mm_options(mat1, mat2),
)
from torch._inductor.ir import get_free_symbols
@ -762,20 +776,18 @@ def tuned_mm(mat1, mat2, *, layout=None):
and use_cutlass_template(layout, m, n, k)
and _use_cutlass_for_op("mm")
):
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
choices, layout, kernel_inputs.nodes()
)
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2])
if is_nonzero and use_ck_gemm_template(layout, m, n, k):
CKGemmTemplate.add_ck_gemm_choices(choices, layout, kernel_inputs.nodes())
CKGemmTemplate.add_ck_gemm_choices(choices, layout, [mat1, mat2])
if is_nonzero and use_ck_tile_gemm_template(layout, m, n, k):
CKTileGemmTemplate.add_choices(choices, layout, kernel_inputs.nodes())
CKTileGemmTemplate.add_choices(choices, layout, [mat1, mat2])
if use_cpp_gemm_template(layout, mat1, mat2):
CppGemmTemplate.add_choices(
choices,
layout,
kernel_inputs.nodes(),
[mat1, mat2],
)
input_nodes = [mat1, mat2]
@ -789,20 +801,14 @@ def tuned_mm(mat1, mat2, *, layout=None):
if use_aten_gemm_kernels():
always_included.append("extern_mm")
num_choices_before_extra_configs = len(choices)
for kwargs in V.choices.get_mm_configs(
# TODO(coconutruben): remove once we deprecate ah
# mm-extra is a hack to keep the ah functionality alive
# while we transition to the unified kwargs retrieval
kernel_inputs,
layout,
"mm-ah",
"mm",
for config in extra_mm_configs(
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
):
mm_template.maybe_append_choice(
choices,
input_nodes=kernel_inputs.nodes(),
input_nodes=(mat1, mat2),
layout=layout,
**kwargs,
**mm_options(config, m, n, k, layout),
)
# using AutoHeuristic for ranking
@ -832,16 +838,13 @@ def tuned_mm(mat1, mat2, *, layout=None):
choices = choices[:num_choices_before_extra_configs]
for k in inductor_config.external_matmul:
choices.append(
lazy_register_extern_choice(k).bind(kernel_inputs.nodes(), layout)
)
choices.append(lazy_register_extern_choice(k).bind((mat1, mat2), layout))
return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout)
return autotune_select_algorithm(name, choices, [mat1, mat2], layout)
@register_lowering(aten._int_mm, type_promotion_kind=None)
def tuned_int_mm(mat1, mat2, *, layout=None):
# TODO(coconutruben): integrate into MMKernelInputs when all callsites use that
m, n, k, layout, mat1, mat2 = mm_args(
mat1, mat2, layout=layout, out_dtype=torch.int32
)
@ -858,6 +861,8 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
layout,
)
device_type = ir.get_device_type(mat1)
static_shape, is_nonzero = _is_static_problem(layout)
use_cutlass = static_shape and is_nonzero and use_cutlass_template(layout, m, n, k)
@ -865,37 +870,33 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
[aten__int_mm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else []
)
# Create MMKernelInputs for Int MM
kernel_inputs = MMKernelInputs([mat1, mat2])
if use_cutlass and _use_cutlass_for_op("int_mm"):
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
choices, layout, kernel_inputs.nodes(), fuseable=True, non_fuseable=True
choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True
)
int8_mm_configs = V.choices.get_int8_mm_configs(device_type)
if is_nonzero and use_triton_template(layout, enable_int32=True):
for kwargs in V.choices.get_mm_configs(
kernel_inputs, layout, mm_template.name, "int_mm"
for config in int8_mm_configs(
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
):
mm_template.maybe_append_choice(
choices,
input_nodes=kernel_inputs.nodes(),
input_nodes=(mat1, mat2),
layout=layout,
**kwargs,
**mm_options(config, m, n, k, layout),
)
return autotune_select_algorithm("int_mm", choices, kernel_inputs.nodes(), layout)
return autotune_select_algorithm("int_mm", choices, [mat1, mat2], layout)
@register_lowering(aten.addmm, type_promotion_kind=None)
def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
# TODO(coconutruben): integrate into MMKernelInputs when all callsites use that
device_type = ir.get_device_type(mat1)
m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout)
static_shape, is_nonzero = _is_static_problem(layout)
# Create MMKernelInputs for AddMM at the top
kernel_inputs = MMKernelInputs([inp_expanded, mat1, mat2])
# below is for getting an overview logging info of inductor mms
counters["aten_mm_info"][f"aten.addmm_{m}_{n}_{k}"] += 1
log.info(
@ -922,9 +923,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
choices = (
[
aten_addmm.bind(
# TODO(coconutruben): replace with kernel_inputs.nodes()
# once that supports the unexpanded nodes as well
[inp, mat1, mat2],
(inp, mat1, mat2),
layout,
alpha=alpha,
beta=beta,
@ -933,19 +932,12 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
if use_aten_gemm_kernels()
else []
)
return autotune_select_algorithm(
# TODO(coconutruben): replace with kernel_inputs.nodes()
# once that supports the unexpanded nodes as well
"addmm",
choices,
[inp, mat1, mat2],
layout,
)
return autotune_select_algorithm("addmm", choices, [inp, mat1, mat2], layout)
choices = (
[
aten_addmm.bind(
kernel_inputs.nodes(),
(inp_expanded, mat1, mat2),
layout,
alpha=alpha,
beta=beta,
@ -965,42 +957,50 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
choices.insert(
0,
aten_bias_addmm.bind(
kernel_inputs.nodes(),
layout,
alpha=alpha,
beta=beta,
(inp_expanded, mat1, mat2), layout, alpha=alpha, beta=beta
),
)
mm_configs = V.choices.get_base_mm_configs(device_type)
persistent_mm_configs = V.choices.get_persistent_mm_configs(device_type)
dtype = mat1.get_dtype()
if is_nonzero and use_triton_template(layout):
# Get template params using the new unified function
for kwargs in V.choices.get_mm_configs(
kernel_inputs, layout, mm_template.name, "addmm"
for config in mm_configs(
m,
n,
k,
**mm_config_kwargs(device_type, _is_large_block_for_cpu, dtype.itemsize),
):
mm_template.maybe_append_choice(
choices,
input_nodes=kernel_inputs.nodes(),
input_nodes=(inp_expanded, mat1, mat2),
layout=layout,
**kwargs,
**mm_options(config, m, n, k, layout),
prefix_args=1,
epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
epilogue_fn_hash=str(["addmm_epilogue", layout.dtype, alpha, beta]),
)
if use_triton_tma_template(mat1, mat2):
# Get TMA template params using the new unified function
for kwargs in V.choices.get_mm_configs(
kernel_inputs, layout, persistent_tma_mm_template.name, "addmm"
for config in persistent_mm_configs(
m,
n,
k,
**mm_config_kwargs(
device_type, _is_large_block_for_cpu, dtype.itemsize
),
):
persistent_tma_mm_template.maybe_append_choice(
choices,
input_nodes=kernel_inputs.nodes(),
input_nodes=(inp_expanded, mat1, mat2),
layout=layout,
workspace_arg=get_tma_workspace_arg(
num_tma_descriptors=2,
device=mat1.get_device(),
),
**kwargs,
**mm_options(config, m, n, k, layout),
**persistent_mm_options(mat1, mat2),
prefix_args=1,
epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
)
@ -1013,20 +1013,17 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
choices,
layout,
# reorder here because CUTLASS expects (x, w, bias) but torch
# is bias, x, w
kernel_inputs.nodes(reorder=[1, 2, 0]),
[mat1, mat2, inp_expanded],
alpha=alpha,
beta=beta,
input_reorder=[2, 0, 1],
)
if is_nonzero and use_ck_gemm_template(layout, m, n, k):
CKGemmTemplate.add_ck_gemm_choices(
choices,
layout,
# reorder here because CK expects (x, w, bias) but torch
# is bias, x, w
kernel_inputs.nodes(reorder=[1, 2, 0]),
[mat1, mat2, inp_expanded],
alpha=alpha,
beta=beta,
input_reorder=[2, 0, 1],
@ -1036,13 +1033,15 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
CppGemmTemplate.add_choices(
choices,
layout,
kernel_inputs.nodes(),
[inp_expanded, mat1, mat2],
alpha=alpha,
beta=beta,
has_bias=True,
)
return autotune_select_algorithm("addmm", choices, kernel_inputs.nodes(), layout)
return autotune_select_algorithm(
"addmm", choices, [inp_expanded, mat1, mat2], layout
)
@register_lowering(aten._sparse_semi_structured_mm, type_promotion_kind=None)
@ -1090,7 +1089,7 @@ def tuned_sparse_semi_structured_mm(
)
return autotune_select_algorithm(
"sparse_semi_structured_mm", choices, (mat1, mat1_meta, mat2), layout
"sparse_semi_structured_mm", choices, [mat1, mat1_meta, mat2], layout
)
@ -1124,7 +1123,6 @@ def tuned_scaled_mm(
Returns:
Tensor: The result of the scaled matrix multiplication
"""
# TODO(coconutruben): integrate into MMKernelInputs when all callsites use that
m, n, k, layout, mat_a, mat_b = mm_args(
mat_a, mat_b, layout=layout, out_dtype=out_dtype
)
@ -1140,6 +1138,7 @@ def tuned_scaled_mm(
layout,
)
device_type = ir.get_device_type(mat_a)
check_supported_striding(mat_a, mat_b)
scale_a_real, scale_b_real = realize_inputs(scale_a, scale_b)
@ -1166,51 +1165,59 @@ def tuned_scaled_mm(
_, is_nonzero = _is_static_problem(layout)
# Prepare triton input nodes and create kernel_inputs at the top
triton_input_nodes: list[Any]
if bias and len(mat_b.get_size()) == len(bias.get_size()) + 1:
# Need to unsqueeze bias from [N] -> [1, N]
triton_bias = L[aten.unsqueeze](bias, 0)
else:
triton_bias = bias
if len(scale_a.get_size()) == 0 or len(scale_b.get_size()) == 0:
assert len(scale_a.get_size()) == len(scale_b.get_size())
# Need to unsqueeze scale from [] -> [1, 1]
triton_scale_a = L[aten.unsqueeze](L[aten.unsqueeze](scale_a, 0), 1)
triton_scale_b = L[aten.unsqueeze](L[aten.unsqueeze](scale_b, 0), 1)
else:
triton_scale_a = scale_a
triton_scale_b = scale_b
if bias:
triton_input_nodes = [
mat_a,
mat_b,
triton_scale_a,
triton_scale_b,
triton_bias,
]
suffix_args = 3
else:
triton_input_nodes = [mat_a, mat_b, triton_scale_a, triton_scale_b]
suffix_args = 2
# Create MMKernelInputs for Scaled MM (matrices are at indices 0, 1)
kernel_inputs = MMKernelInputs(triton_input_nodes, mat1_idx=0, mat2_idx=1)
scaled_mm_configs = V.choices.get_scaled_mm_configs(device_type)
scaled_persistent_mm_configs = V.choices.get_scaled_persistent_mm_configs(
device_type
)
if is_nonzero and use_triton_template(layout, enable_float8=True):
triton_input_nodes: tuple[Any, ...]
if bias and len(mat_b.get_size()) == len(bias.get_size()) + 1:
# Need to unsqueeze bias from [N] -> [1, N]
triton_bias = L[aten.unsqueeze](bias, 0)
else:
triton_bias = bias
if len(scale_a.get_size()) == 0 or len(scale_b.get_size()) == 0:
assert len(scale_a.get_size()) == len(scale_b.get_size())
# Need to unsqueeze scale from [] -> [1, 1]
triton_scale_a = L[aten.unsqueeze](L[aten.unsqueeze](scale_a, 0), 1)
triton_scale_b = L[aten.unsqueeze](L[aten.unsqueeze](scale_b, 0), 1)
else:
triton_scale_a = scale_a
triton_scale_b = scale_b
if bias:
triton_input_nodes = (
mat_a,
mat_b,
triton_scale_a,
triton_scale_b,
triton_bias,
)
suffix_args = 3
else:
triton_input_nodes = (mat_a, mat_b, triton_scale_a, triton_scale_b)
suffix_args = 2
# TODO (paulzhan): There is no template that exists for bias and TMA
# Don't run tma template currently if bias exists
if use_triton_tma_template(mat_a, mat_b) and not bias:
# Get TMA template params using the new unified function
for kwargs in V.choices.get_mm_configs(
kernel_inputs, layout, scaled_mm_device_tma_template.name, "scaled_mm"
):
kwargs["USE_FAST_ACCUM"] = use_fast_accum
for config in scaled_persistent_mm_configs(m, n, k):
kwargs = scaled_mm_options(
config,
m,
n,
k,
layout,
scale_a,
scale_b,
use_fast_accum,
device_tma=True,
)
scaled_mm_device_tma_template.maybe_append_choice(
choices,
input_nodes=kernel_inputs.nodes(),
input_nodes=triton_input_nodes,
layout=layout,
workspace_arg=get_tma_workspace_arg(
num_tma_descriptors=2,
@ -1219,11 +1226,7 @@ def tuned_scaled_mm(
**kwargs,
)
# Get template params using the new unified function
for kwargs in V.choices.get_mm_configs(
kernel_inputs, layout, mm_template.name, "scaled_mm"
):
kwargs["USE_FAST_ACCUM"] = use_fast_accum
for config in scaled_mm_configs(m, n, k):
if V.graph.sizevars.guard_or_false(sympy.Le(k, 16)):
# Triton crashes however uncommon for real workloads
continue
@ -1233,10 +1236,13 @@ def tuned_scaled_mm(
if using_b200() and V.graph.sizevars.guard_or_false(sympy.Lt(k, 32)):
continue
kwargs = scaled_mm_options(
config, m, n, k, layout, scale_a, scale_b, use_fast_accum
)
# possibly appends a TritonTemplateCaller to choices
mm_template.maybe_append_choice(
choices,
input_nodes=kernel_inputs.nodes(),
input_nodes=triton_input_nodes,
layout=layout,
**kwargs,
suffix_args=suffix_args,
@ -1252,16 +1258,14 @@ def tuned_scaled_mm(
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
choices,
layout,
kernel_inputs.nodes(), # type: ignore[arg-type]
input_nodes, # type: ignore[arg-type]
use_fast_accum=use_fast_accum, # type: ignore[arg-type]
)
if is_nonzero and use_ck_gemm_template(layout, m, n, k):
CKGemmTemplate.add_ck_gemm_choices(choices, layout, kernel_inputs.nodes())
CKGemmTemplate.add_ck_gemm_choices(choices, layout, input_nodes)
return autotune_select_algorithm(
"scaled_mm", choices, kernel_inputs.nodes(), layout
)
return autotune_select_algorithm("scaled_mm", choices, input_nodes, layout)
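
For the scaled path above, the template kwargs now come from scaled_mm_options (restored in mm_common.py, the next file in this diff) instead of V.choices.get_mm_configs with USE_FAST_ACCUM patched in afterwards. A rough sketch of what it adds on top of the plain mm_options keys, per the definition shown later in this diff:

kwargs = scaled_mm_options(
    config, m, n, k, layout, scale_a, scale_b, use_fast_accum, device_tma=True
)
# In addition to the plain mm_options keys this sets:
#   ACC_TYPE        = "tl.float32"                       # always forced for scaled MM
#   USE_FAST_ACCUM  = use_fast_accum
#   SCALING_ROWWISE = len(scale_a.get_size()) == 2
# and, because device_tma=True, also TMA_SIZE, NUM_SMS and the tma_options() flags.
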
@functools.cache

View File

@ -3,13 +3,17 @@ import logging
from collections.abc import Sequence
from typing import Any
import sympy
import torch
from torch._inductor.select_algorithm import realize_inputs, SymbolicGridFn
from torch._inductor.utils import sympy_product
from torch._inductor.virtualized import V
from .. import config as inductor_config
from ..codegen.wrapper import PythonWrapperCodegen
from ..ir import _IntLike, Layout, TensorBox
from ..utils import get_num_sms, TMA_DESCRIPTOR_SIZE
log = logging.getLogger(__name__)
@ -45,6 +49,96 @@ def acc_type(dtype):
return f"tl.{dtype}".replace("torch.", "")
def mm_options(config, sym_m, sym_n, sym_k, layout):
"""
Common options to matmul triton templates.
"""
even_k_symbolic = (
# it isn't worth guarding on this
sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"]
)
allow_tf32 = torch.backends.cuda.matmul.allow_tf32 and (
not inductor_config.force_same_precision
or ((sym_m % 16) == 0 and (sym_n % 16) == 0 and (sym_k % 8) == 0)
)
options_dict = dict(
EVEN_K=even_k_symbolic,
ALLOW_TF32=allow_tf32,
USE_FAST_ACCUM=False, # Option for _scaled_mm
ACC_TYPE=acc_type(layout.dtype),
num_stages=config.num_stages,
num_warps=config.num_warps,
**config.kwargs,
)
# If GROUP_M not specified then default to 8
if "GROUP_M" not in config.kwargs:
group_m = config.kwargs.get("GROUP_M", 8)
options_dict["GROUP_M"] = group_m
return options_dict
def tma_options() -> dict[str, Any]:
from torch.utils._triton import has_triton_stable_tma_api
return {"TMA_EXPERIMENTAL_API": not has_triton_stable_tma_api()}
def persistent_mm_options(mat1, mat2):
res = {
"A_ROW_MAJOR": not mat1.layout.is_transposed(),
"B_ROW_MAJOR": not mat2.layout.is_transposed(),
"NUM_SMS": get_num_sms(),
"TMA_SIZE": TMA_DESCRIPTOR_SIZE,
}
res.update(tma_options())
return res
def scaled_mm_options( # type: ignore[no-untyped-def]
config, # triton.Config
sym_m: sympy.core.numbers.Integer,
sym_n: sympy.core.numbers.Integer,
sym_k: sympy.core.numbers.Integer,
layout: Layout,
scale_a,
scale_b,
use_fast_accum: bool,
device_tma: bool = False,
) -> dict[str, Any]:
def are_compatible_scales(size_a, size_b) -> bool:
# Same sized scales are compatible
if len(size_a) == len(size_b):
return True
# Both need to be scalars or len(1) tensors
if len(size_a) <= 1 and len(size_b) <= 1:
return True
return False
size_a, size_b = scale_a.get_size(), scale_b.get_size()
assert are_compatible_scales(size_a, size_b), (
"Expect scale_a and scale_b to be either both scalars (including single-element tensors) "
f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}."
)
mm_template_options = mm_options(config, sym_m, sym_n, sym_k, layout)
mm_template_options["ACC_TYPE"] = "tl.float32"
mm_template_options["USE_FAST_ACCUM"] = use_fast_accum
mm_template_options["SCALING_ROWWISE"] = len(size_a) == 2
if device_tma:
mm_template_options["TMA_SIZE"] = TMA_DESCRIPTOR_SIZE
mm_template_options["NUM_SMS"] = get_num_sms()
mm_template_options.update(tma_options())
return mm_template_options
def mm_args(
mat1,
mat2,
@ -87,6 +181,20 @@ def mm_args(
return [m, n, k, layout, mat1, mat2, *others]
def mm_config_kwargs(device, exclude_condition, dtype_size=None):
if device == "cpu":
return {
"scale": 0.5,
"exclude": exclude_condition,
}
if dtype_size and inductor_config.max_autotune_gemm_search_space == "EXHAUSTIVE":
return {
"dtype_size": dtype_size,
}
return {}
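
A small worked example of the two restored helpers above; the values are illustrative only (the config object stands in for a triton.Config produced by the heuristics):

# mm_config_kwargs: extra kwargs the lowerings splat into the config generators
mm_config_kwargs("cpu", _is_large_block_for_cpu)       # -> {"scale": 0.5, "exclude": _is_large_block_for_cpu}
mm_config_kwargs("cuda", _is_large_block_for_cpu, 2)   # -> {"dtype_size": 2} under EXHAUSTIVE search, else {}
mm_config_kwargs("cuda", _is_large_block_for_cpu)      # -> {}

# mm_options: assuming config.kwargs == {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32},
# num_stages=4, num_warps=8 and a float16/bfloat16 output layout, it yields roughly:
#   {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8,   # GROUP_M defaulted to 8
#    "EVEN_K": sympy.gcd(k, 32) == 32, "ALLOW_TF32": <global tf32 setting + alignment check>,
#    "USE_FAST_ACCUM": False, "ACC_TYPE": "tl.float32",
#    "num_stages": 4, "num_warps": 8}
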
def addmm_epilogue(dtype, alpha, beta):
def epilogue(acc, bias):
if alpha != 1:

View File

@ -1,10 +1,8 @@
# mypy: allow-untyped-defs
import logging
import torch
from ..kernel_inputs import MMKernelInputs
from .. import ir
from ..lowering import lowerings
from ..select_algorithm import (
autotune_select_algorithm,
@ -13,11 +11,9 @@ from ..select_algorithm import (
)
from ..utils import use_aten_gemm_kernels, use_triton_template
from ..virtualized import V
from .mm_common import mm_args, mm_grid
from .mm_common import mm_args, mm_grid, mm_options
log = logging.getLogger(__name__)
aten = torch.ops.aten
aten_mm_plus_mm = ExternKernelChoice(
@ -123,9 +119,9 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
"""
Computes mm(mat1, mat2) + mm(mat3, mat4)
"""
# TODO(coconutruben): integrate into MMKernelInputs when all callsites use that
m1, n1, k1, layout1, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
m2, n2, _, layout2, mat3, mat4 = mm_args(mat3, mat4, layout=layout)
device_type = ir.get_device_type(mat1)
# Optimization is optional, because we can always just not do the fusion
if (
@ -144,34 +140,27 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
lowerings[aten.mm](mat1, mat2), lowerings[aten.mm](mat3, mat4)
)
# Create MMKernelInputs for MM Plus MM (matrices are at indices 0, 1 for first pair)
# Note: This is a special case with 4 matrices, but we use the first pair for M, N, K extraction
kernel_inputs = MMKernelInputs([mat1, mat2, mat3, mat4], mat1_idx=0, mat2_idx=1)
assert layout1 == layout2
# options to tune from
choices = (
[aten_mm_plus_mm.bind(kernel_inputs.nodes(), layout1)]
[aten_mm_plus_mm.bind((mat1, mat2, mat3, mat4), layout1)]
if use_aten_gemm_kernels()
else []
)
mm_configs = V.choices.get_mm_plus_mm_configs(device_type)
if use_triton_template(layout1):
# Get template params using the new unified function
for kwargs in V.choices.get_mm_configs(
kernel_inputs, layout1, mm_plus_mm_template.name, "mm_plus_mm"
):
# Apply BLOCK_K constraint specific to mm_plus_mm
for config in mm_configs():
# see https://github.com/triton-lang/triton/issues/1298
# BLOCK_K = K causes llvm error
if V.graph.sizevars.statically_known_lt(kwargs.get("BLOCK_K", k1), k1):
if V.graph.sizevars.statically_known_lt(config.kwargs["BLOCK_K"], k1):
mm_plus_mm_template.maybe_append_choice(
choices,
input_nodes=kernel_inputs.nodes(),
input_nodes=(mat1, mat2, mat3, mat4),
layout=layout1,
**kwargs,
**mm_options(config, m1, n1, k1, layout1),
)
return autotune_select_algorithm(
"mm_plus_mm", choices, kernel_inputs.nodes(), layout1
"mm_plus_mm", choices, [mat1, mat2, mat3, mat4], layout1
)

View File

@ -1,237 +0,0 @@
from __future__ import annotations
from typing import Any, Optional, TYPE_CHECKING
import torch
import torch._inductor.config
from torch._inductor import ir
from torch._inductor.virtualized import V
if TYPE_CHECKING:
from collections.abc import Sequence
import sympy
class KernelInputs:
"""
Class to store and provide access to input nodes for kernels.
This class takes in a tuple of input nodes and provides methods to access
information about these nodes, such as their device type and device.
"""
def __init__(self, input_nodes: list[Any]):
"""
Initialize with a tuple of input nodes.
Args:
input_nodes: A tuple of input nodes to store
"""
self._input_nodes = input_nodes
self._device_name: Optional[str] = None
assert len(input_nodes) > 0, "Expected at least one input node"
def nodes(self, reorder: Optional[Sequence[int]] = None) -> list[Any]:
"""
Return the stored input nodes, optionally reordered.
Args:
reorder: Optional sequence of indices to reorder the nodes.
For example, (2, 0, 1) would return nodes in that order.
Returns:
The tuple of input nodes, optionally reordered
"""
if reorder is None:
return self._input_nodes
assert len(self._input_nodes) == len(reorder), (
f"Reorder length mismatch: {len(self._input_nodes)} vs {len(reorder)}"
)
return [self._input_nodes[i] for i in reorder]
@property
def device_type(self) -> Optional[str]:
"""
Get the device type of the first node.
Returns:
The device type (e.g., 'cuda', 'cpu')
"""
return ir.get_device_type(self._input_nodes[0])
def device(self) -> torch.device:
"""
Get the device of the first node.
Returns:
The device of the first node
"""
return self._input_nodes[0].get_device()
def device_name(self) -> Optional[str]:
"""
Get the device name information.
Returns:
A tuple of (gpu_name, vendor, model)
"""
if self._device_name is None:
device = self.device()
if self.device_type == "cuda":
device_properties = torch.cuda.get_device_properties(device)
self._device_name = device_properties.gcnArchName
return self._device_name
def shapes_symbolic(self) -> tuple[tuple[Any, ...], ...]:
"""
Get the symbolic shapes of all input nodes.
Returns:
A tuple of shape tuples for each input node
"""
return tuple(node.get_size() for node in self._input_nodes)
def shapes_hinted(self) -> tuple[tuple[int, ...], ...]:
"""
Get the size hints for shapes of all input nodes.
Returns:
A tuple of shape tuples with integer hints for each input node
"""
return tuple(
V.graph.sizevars.size_hints(
node.get_size(),
fallback=torch._inductor.config.unbacked_symint_fallback,
)
for node in self._input_nodes
)
def strides_symbolic(self) -> tuple[tuple[sympy.Integer, ...], ...]:
"""
Get the symbolic strides of all input nodes.
Returns:
A tuple of stride tuples for each input node
"""
return tuple(node.get_stride() for node in self._input_nodes)
def strides_hinted(self) -> tuple[tuple[int, ...], ...]:
"""
Get the size hints for strides of all input nodes.
Returns:
A tuple of stride tuples with integer hints for each input node
"""
return tuple(
V.graph.sizevars.size_hints(
node.get_stride(),
fallback=torch._inductor.config.unbacked_symint_fallback,
)
for node in self._input_nodes
)
def dtypes(self) -> tuple[torch.dtype, ...]:
"""
Get the dtypes of all input nodes.
Returns:
A tuple of dtypes for each input node
"""
return tuple(node.get_dtype() for node in self._input_nodes)
def dtype(self, idx: int = 0) -> torch.dtype:
"""
Get the dtype of a specific input node.
Args:
idx: Index of the node to get the dtype from (default: 0)
Returns:
The dtype of the specified input node
"""
return self._input_nodes[idx].get_dtype()
class MMKernelInputs(KernelInputs):
"""
Specialized KernelInputs for matrix multiplication operations.
Provides additional methods to access M, N, K dimensions.
"""
def __init__(self, input_nodes: list[Any], mat1_idx: int = -2, mat2_idx: int = -1):
"""
Initialize with a tuple of input nodes.
By default, we assume the last 2 input nodes are mat1 and mat2, but
the caller can adjust when necessary
"""
super().__init__(input_nodes)
# for mm, we need at least 2 nodes, and we need to know which nodes
# are the main matrices, e.g. addmm is (bias, mat1, mat2) whereas others
# might be (mat1, mat2, scale), etc.
assert len(self._input_nodes) >= 2, "Expected at least 2 input nodes"
# Adjust assertions to handle negative indices
m1_idx, m2_idx = mat1_idx, mat2_idx
if mat1_idx < 0:
m1_idx += len(input_nodes)
if mat2_idx < 0:
m2_idx += len(input_nodes)
assert 0 <= m1_idx < len(input_nodes), f"Invalid mat1_idx: {mat1_idx}"
assert 0 <= m2_idx < len(input_nodes), f"Invalid mat2_idx: {mat2_idx}"
self._mat1_idx = mat1_idx
self._mat2_idx = mat2_idx
def mnk_symbolic(
self,
) -> tuple[sympy.Integer, sympy.Integer, sympy.Integer]:
"""
Get the symbolic M, N, K dimensions for matrix multiplication.
Handles both 2D (MM) and 3D (BMM) tensors.
M is extracted from the second-to-last dimension of the first operand (mat1).
N is extracted from the last dimension of the second operand (mat2).
K is extracted from the last dimension of the first operand (mat1).
Returns:
A tuple of (M, N, K) dimensions
"""
mat1 = self.nodes()[self._mat1_idx]
mat2 = self.nodes()[self._mat2_idx]
m = mat1.get_size()[-2] # M from second-to-last dimension of mat1
k = mat1.get_size()[-1] # K from last dimension of mat1
n = mat2.get_size()[-1] # N from last dimension of mat2
# Ensure K dimensions match between operands
k0 = mat2.get_size()[-2] # K from second-to-last dimension of mat2
V.graph.sizevars.check_equals(k, k0)
return (m, n, k)
def mnk_hinted(self) -> tuple[int, int, int]:
"""
Get the hinted M, N, K dimensions for matrix multiplication.
Handles both 2D (MM) and 3D (BMM) tensors.
Uses shapes_hinted from the base class to get integer hints for dimensions.
Returns:
A tuple of (M, N, K) dimensions as integers
"""
hinted_shapes = self.shapes_hinted()
mat1_shape = hinted_shapes[self._mat1_idx]
mat2_shape = hinted_shapes[self._mat2_idx]
m = mat1_shape[-2] # M from second-to-last dimension of mat1
k = mat1_shape[-1] # K from last dimension of mat1
n = mat2_shape[-1] # N from last dimension of mat2
# Ensure K dimensions match between operands
k_check = mat2_shape[-2] # K from second-to-last dimension of mat2
assert k == k_check, f"K dimensions don't match: {k} vs {k_check}"
return (m, n, k)
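
For context, a condensed sketch of how the pre-revert lowerings used this class, assembled from the removed lines in kernel/mm.py earlier in this diff (addmm shown; other ops were analogous; prefix/epilogue arguments omitted for brevity):

kernel_inputs = MMKernelInputs([inp_expanded, mat1, mat2])   # bias first; matrices default to the last two slots
# the template heuristics later derive M/N/K via kernel_inputs.mnk_symbolic() and the device via .device_type
for kwargs in V.choices.get_mm_configs(kernel_inputs, layout, mm_template.name, "addmm"):
    mm_template.maybe_append_choice(
        choices, input_nodes=kernel_inputs.nodes(), layout=layout, **kwargs
    )
# CUTLASS/CK expect (x, w, bias) while torch's addmm is (bias, x, w):
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
    choices, layout, kernel_inputs.nodes(reorder=[1, 2, 0]),
    alpha=alpha, beta=beta, input_reorder=[2, 0, 1],
)
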

View File

@ -7,16 +7,11 @@ from functools import partial
from threading import Lock
from typing import Any, Callable, Optional, TYPE_CHECKING
import sympy
import torch
from torch.utils._ordered_set import OrderedSet
from torch.utils._triton import has_triton_stable_tma_api
from . import config, config as inductor_config
from .kernel_inputs import KernelInputs, MMKernelInputs
from .template_registry import register_template_heuristic
from .utils import get_backend_num_stages, get_num_sms, TMA_DESCRIPTOR_SIZE
from . import config
from .utils import get_backend_num_stages
from .virtualized import V
@ -152,9 +147,6 @@ class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
"""
def __init__(self) -> None:
# Whether the heuristic is used for int8. Use this when the heuristic is int8 exclusive
# but prefer the preprocess_mm_configs argument when it's used for both
self.has_int8_tensor: bool = False
# List of dictionaries to store the kernel configs. Configs that evaluate to true
# will be utilised on the target platform. The configs are as follows:
# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
@ -475,7 +467,7 @@ class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
configs: list[BaseConfig],
scale: float,
has_int8_tensor: bool,
exclude: Callable[[sympy.Integer, sympy.Integer, sympy.Integer], bool],
exclude: Callable[[int, int, int], bool],
hint_override: Optional[int] = None,
) -> list[BaseConfig]:
"""
@ -484,7 +476,7 @@ class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
from .runtime.runtime_utils import next_power_of_2
min_block_size = 16
min_block_size_k = 32 if (has_int8_tensor or self.has_int8_tensor) else 16
min_block_size_k = 32 if has_int8_tensor else 16
scaled_configs = []
for hint_override in [None] + config.multi_kernel_hints:
@ -569,13 +561,6 @@ class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
return pruned_configs
def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]:
"""
Filter configs based on specific requirements.
Subclasses can override this to implement custom filtering logic.
"""
return configs
def preprocess_mm_configs(
self,
m: int,
@ -583,17 +568,14 @@ class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
k: int,
configs: list[BaseConfig],
has_int8_tensor: bool = False,
scale: float = 1.0,
exclude: Callable[
[sympy.Integer, sympy.Integer, sympy.Integer], bool
] = lambda m, n, k: False,
scale: int = 1,
exclude: Callable[[int, int, int], bool] = lambda m, n, k: False,
dtype_size: int = 0,
op_name: str = "mm", # For preprocessing overrides e.g. on CPU
) -> Generator[TritonConfig, None, None]:
configs = self._filter_configs(configs)
scaled_configs = self._scale_mm_configs(
m, n, k, configs, scale, has_int8_tensor, exclude
)
if config.max_autotune_gemm_search_space == "EXHAUSTIVE":
assert dtype_size > 0, "dtype_size must be provided for exhaustive search"
scaled_configs = self._prune_exhaustive_configs(scaled_configs, dtype_size)
@ -612,10 +594,48 @@ class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
def get_exhaustive_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
return partial(self.preprocess_mm_configs, configs=self.exhaustive_configs)
def get_conv_configs(self) -> partial[Generator[TritonConfig, None, None]]:
return partial(
self.preprocess_mm_configs, configs=self.conv_configs, op_name="conv"
def get_extra_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
return partial(self.preprocess_mm_configs, configs=self.extra_mm_configs)
def get_int8_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
return partial(self.preprocess_mm_configs, configs=self.int8_mm_configs)
def get_mixed_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
mm_configs = (
self.mm_configs + self.mixed_mm_configs
if config.max_autotune_gemm_search_space == "EXHAUSTIVE"
else self.mm_configs
)
return partial(self.preprocess_mm_configs, configs=mm_configs)
def get_persistent_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
persistent_mm_configs = (
self.exhaustive_configs
if config.max_autotune_gemm_search_space == "EXHAUSTIVE"
else self.persistent_mm_configs
)
# num_warps=2 not safe for TMA
persistent_mm_configs = [
config for config in persistent_mm_configs if config.num_warps != 2
]
return partial(self.preprocess_mm_configs, configs=persistent_mm_configs)
def get_scaled_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
return partial(self.preprocess_mm_configs, configs=self.scaled_mm_configs)
def get_scaled_persistent_mm_configs(
self,
) -> partial[Generator[TritonConfig, None, None]]:
return partial(
self.preprocess_mm_configs, configs=self.scaled_persistent_mm_configs
)
def get_mm_plus_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
return partial(self._finalize_mm_configs, configs=self.mm_plus_mm_configs)
def get_conv_configs(self) -> partial[Generator[TritonConfig, None, None]]:
return partial(self.preprocess_mm_configs, configs=self.conv_configs)
# Flex attn helpers
def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]:
@ -676,80 +696,7 @@ class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
class CPUConfigHeuristic(BaseConfigHeuristic):
"""
CPU-specific config heuristic with CPU-specific optimizations.
"""
def _get_cpu_exclude_function(
self, method: str = "bmm"
) -> Callable[[sympy.Integer, sympy.Integer, sympy.Integer], bool]:
"""
Get CPU-specific exclude function based on method type.
Returns a function that can be used as exclude condition.
Moved from mm_common._is_large_block_for_cpu and refactored to return a function.
"""
if method in ("conv",):
def exclude_conv(
m: sympy.Integer, n: sympy.Integer, k: sympy.Integer
) -> bool:
# Thresholds are experimentally determined to reduce Triton CPU compile times
if m > 256 or n > 256 or k > 256:
return True
return m * n * k > 2**17
return exclude_conv
elif method in ("mm", "addmm", "int_mm"):
def exclude_mm(
m: sympy.Integer, n: sympy.Integer, k: sympy.Integer
) -> bool:
return m * n > 2**13
return exclude_mm
else: # Default to bmm implementation for unknown methods
def exclude_bmm(
m: sympy.Integer, n: sympy.Integer, k: sympy.Integer
) -> bool:
if m > 128 or n > 128 or k > 128:
return True
return m * n > 2**12
return exclude_bmm
def preprocess_mm_configs(
self,
m: int,
n: int,
k: int,
configs: list[BaseConfig],
has_int8_tensor: bool = False,
scale: float = 1.0,
exclude: Callable[
[sympy.Integer, sympy.Integer, sympy.Integer], bool
] = lambda m, n, k: False,
dtype_size: int = 0,
op_name: str = "mm", # For preprocessing overrides e.g. on CPU
) -> Generator[TritonConfig, None, None]:
"""
CPU-specific preprocessing that applies CPU-specific scaling (0.5) and exclusion logic.
"""
# Get CPU-specific exclude function based on operation type
cpu_exclude_fn = self._get_cpu_exclude_function(op_name)
# Apply CPU-specific scaling (0.5) and exclusion logic
return super().preprocess_mm_configs(
m,
n,
k,
configs=configs,
has_int8_tensor=has_int8_tensor,
scale=0.5,
exclude=cpu_exclude_fn,
dtype_size=dtype_size,
op_name=op_name,
)
pass
class CUDAConfigHeuristic(BaseConfigHeuristic):
@ -1055,13 +1002,14 @@ class ROCmConfigHeuristic(BaseConfigHeuristic):
for wpeu in [0, int(8 // num_warps)]
]
def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]:
"""
ROCm specific filtering
"""
def _filter_configs(
self, configs: list[BaseConfig], new_num_stages: int
) -> list[BaseConfig]:
# TODO: _filter_configs can be removed once backend specific configs are added
# for all methods
for c in configs:
c.num_stages = self.default_num_stages
return super()._filter_configs(configs)
return configs
def _finalize_mm_configs(
self,
@ -1128,6 +1076,57 @@ class ROCmConfigHeuristic(BaseConfigHeuristic):
kwargs["GROUP_M"] = group_m
yield self.triton_config(**kwargs)
def get_extra_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
filtered_configs = self._filter_configs(
self.extra_mm_configs, self.default_num_stages
)
return partial(self.preprocess_mm_configs, configs=filtered_configs)
def get_int8_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
filtered_configs = self._filter_configs(
self.int8_mm_configs, self.default_num_stages
)
return partial(self.preprocess_mm_configs, configs=filtered_configs)
def get_mixed_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
mm_configs = (
self.mm_configs + self.mixed_mm_configs
if config.max_autotune_gemm_search_space == "EXHAUSTIVE"
else self.mm_configs
)
filtered_configs = self._filter_configs(mm_configs, self.default_num_stages)
return partial(self.preprocess_mm_configs, configs=filtered_configs)
def get_persistent_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
filtered_configs = self._filter_configs(
self.persistent_mm_configs, self.default_num_stages
)
return partial(self.preprocess_mm_configs, configs=filtered_configs)
def get_scaled_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
filtered_configs = self._filter_configs(
self.scaled_mm_configs, self.default_num_stages
)
return partial(self.preprocess_mm_configs, configs=filtered_configs)
def get_scaled_persistent_mm_configs(
self,
) -> partial[Generator[TritonConfig, None, None]]:
filtered_configs = self._filter_configs(
self.scaled_persistent_mm_configs, self.default_num_stages
)
return partial(self.preprocess_mm_configs, configs=filtered_configs)
def get_mm_plus_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
filtered_configs = self._filter_configs(self.mm_plus_mm_configs, 1)
return partial(self._finalize_mm_configs, configs=filtered_configs)
def get_conv_configs(self) -> partial[Generator[TritonConfig, None, None]]:
filtered_configs = self._filter_configs(
self.conv_configs, self.default_num_stages
)
return partial(self.preprocess_mm_configs, configs=filtered_configs)
def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]:
flex_attn_fwd_configs: list[FlexConfig] = []
@ -1208,643 +1207,3 @@ class MTIAConfigHeuristic(BaseConfigHeuristic):
"""
Placeholder child class for MTIA specific overrides.
"""
# Template-specific mixin classes
class TemplateConfigHeuristics:
def get_template_configs(
self,
kernel_inputs: KernelInputs,
layout: Any,
op_name: str,
) -> Generator[dict[str, Any], None, None]:
"""
Get template configs for the given inputs.
This is the main entry point for template-specific logic.
"""
# NOTE: not an abstract class, because that clashed below for the mixin
# functionality. Can be adjusted, but not a high priority
yield from {}
class MMTemplateConfigMixin(TemplateConfigHeuristics):
"""
Mixin class that converts config lists to template kwargs.
This handles the logic that was previously in choices.get_mm_configs.
This mixin expects to be used with BaseConfigHeuristic or its subclasses.
"""
# Type annotations to ensure the mixin works with BaseConfigHeuristic
get_mm_configs: Callable[[], partial[Generator[TritonConfig, None, None]]]
get_exhaustive_mm_configs: Callable[
[], partial[Generator[TritonConfig, None, None]]
]
_filter_configs: Callable[[list[BaseConfig]], list[BaseConfig]]
def _get_config_generator(
self,
) -> partial[Generator[TritonConfig, None, None]]:
"""
Get the appropriate config generator based on search space.
Can be overridden by subclasses for template-specific behavior.
"""
# Handle exhaustive search case
if config.max_autotune_gemm_search_space == "EXHAUSTIVE":
return self.get_exhaustive_mm_configs()
else:
return self.get_mm_configs()
def get_template_configs(
self,
kernel_inputs: KernelInputs,
layout: Any,
op_name: str,
) -> Generator[dict[str, Any], None, None]:
"""
Convert config lists to template kwargs.
This replaces the logic from choices.get_mm_configs and inlines mm_options.
"""
assert isinstance(kernel_inputs, MMKernelInputs), (
f"{self.__class__.__name__} requires MMKernelInputs"
)
input_nodes = kernel_inputs.nodes()
if len(input_nodes) < 2:
raise ValueError(f"Need at least 2 input tensors, got {len(input_nodes)}")
# Extract M, N, K from kernel_inputs
m, n, k = kernel_inputs.mnk_symbolic()
# Extract dtype and device_type from kernel_inputs
dtype = kernel_inputs.dtype()
# Get the appropriate config generator
configs = self._get_config_generator()
# Generate and process configs
for c in configs(m, n, k, dtype_size=dtype.itemsize, op_name=op_name):
template_kwargs = self._convert_config_to_template_kwargs(
c, m, n, k, layout
)
yield template_kwargs
def _convert_config_to_template_kwargs(
self,
triton_config: TritonConfig,
m: sympy.Integer,
n: sympy.Integer,
k: sympy.Integer,
layout: Any,
) -> dict[str, Any]:
"""
Convert triton config to template kwargs.
Moved from mm_common.mm_options.
"""
# Calculate EVEN_K symbolic
even_k_symbolic = (
# it isn't worth guarding on this
sympy.gcd(k, triton_config.kwargs["BLOCK_K"])
== triton_config.kwargs["BLOCK_K"]
)
# Calculate allow_tf32
allow_tf32 = torch.backends.cuda.matmul.allow_tf32 and (
not inductor_config.force_same_precision
or ((m % 16) == 0 and (n % 16) == 0 and (k % 8) == 0)
)
# Build options dict
options_dict = dict(
EVEN_K=even_k_symbolic,
ALLOW_TF32=allow_tf32,
USE_FAST_ACCUM=False, # Option for _scaled_mm
ACC_TYPE=self._get_acc_type(layout.dtype),
num_stages=triton_config.num_stages,
num_warps=triton_config.num_warps,
**triton_config.kwargs,
)
# If GROUP_M not specified then default to 8
if "GROUP_M" not in triton_config.kwargs:
group_m = triton_config.kwargs.get("GROUP_M", 8)
options_dict["GROUP_M"] = group_m
return options_dict
def _get_acc_type(self, dtype: torch.dtype) -> str:
"""
Get accumulator type for the given dtype.
Moved from mm_common.acc_type.
"""
if dtype in (torch.float16, torch.bfloat16):
return "tl.float32"
return f"tl.{dtype}".replace("torch.", "")
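
Taken together with the registry decorators at the bottom of this file and the get_mm_configs method removed from choices.py earlier in this diff, the mixin above formed the pre-revert retrieval path. A minimal sketch, simplified from those hunks:

# inside the removed V.choices.get_mm_configs:
heuristic = get_template_heuristic(mm_template.name, "cuda", "mm")  # -> a CUDAMMTemplateConfigHeuristic instance
for kwargs in heuristic.get_template_configs(kernel_inputs, layout, "mm"):
    # kwargs already carries BLOCK_M/N/K, GROUP_M, EVEN_K, ALLOW_TF32, ACC_TYPE,
    # num_stages and num_warps -- what mm_options computes in the restored code
    mm_template.maybe_append_choice(
        choices, input_nodes=kernel_inputs.nodes(), layout=layout, **kwargs
    )
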
# INT8 specific mixin to filter correctly
class INT8MMTemplateConfigMixin(MMTemplateConfigMixin):
"""
Ensure that we feed in has_int8_tensor=True
"""
def __init__(self) -> None:
super().__init__()
self.has_int8_tensor = True
# TMA-specific mixin for TMA templates
class TMAConfigMixin(MMTemplateConfigMixin):
"""
TMA-specific mixin that uses persistent configs and adds TMA options.
This inherits from MMTemplateConfigMixin and overrides config generation.
"""
def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]:
"""
TMA specific filtering, as num_warps=2 not safe for TMA
"""
configs = [c for c in configs if c.num_warps != 2]
return super()._filter_configs(configs)
def get_template_configs(
self,
kernel_inputs: KernelInputs,
layout: Any,
op_name: str,
) -> Generator[dict[str, Any], None, None]:
"""
Generate TMA template configs by calling super and adding TMA-specific options.
"""
# Get base template configs from superclass
for template_kwargs in super().get_template_configs(
kernel_inputs, layout, op_name
):
# Add TMA-specific options (moved from mm_common.persistent_mm_options)
input_nodes = kernel_inputs.nodes()
self._add_tma_options(template_kwargs, input_nodes)
yield template_kwargs
def _add_tma_options(
self, template_kwargs: dict[str, Any], input_nodes: list[Any]
) -> None:
"""
Add TMA-specific options to template kwargs.
Moved from mm_common.persistent_mm_options and mm_common.tma_options.
"""
# For TMA templates, we need the actual matrix tensors
mat1 = input_nodes[-2]
mat2 = input_nodes[-1]
tma_opts = {
"A_ROW_MAJOR": not mat1.layout.is_transposed(),
"B_ROW_MAJOR": not mat2.layout.is_transposed(),
"NUM_SMS": get_num_sms(),
"TMA_SIZE": TMA_DESCRIPTOR_SIZE,
"TMA_EXPERIMENTAL_API": not has_triton_stable_tma_api(),
}
template_kwargs.update(tma_opts)
# Scaled MM-specific mixin for scaled MM templates (non-TMA)
class ScaledMMConfigMixin(MMTemplateConfigMixin):
"""
Scaled MM-specific mixin that uses scaled configs and adds scaled MM options.
This is for non-TMA scaled MM templates only.
This inherits from MMTemplateConfigMixin and overrides config generation.
"""
def get_template_configs(
self,
kernel_inputs: KernelInputs,
layout: Any,
op_name: str,
) -> Generator[dict[str, Any], None, None]:
"""
Generate scaled MM template configs with scaled MM-specific options.
Handles the remaining logic from mm_common including assertions and SCALING_ROWWISE.
"""
input_nodes = kernel_inputs.nodes()
# Initial assertion from mm_common.scaled_mm_options
assert len(input_nodes) >= 4, (
f"scaled_mm requires at least 4 inputs, got {len(input_nodes)}"
)
# Extract scale tensors (typically scale_a and scale_b are input_nodes[2] and input_nodes[3])
scale_a = input_nodes[2]
scale_b = input_nodes[3]
# Scale compatibility assertion from mm_common.scaled_mm_options
def are_compatible_scales(size_a: Any, size_b: Any) -> bool:
# Same sized scales are compatible
if len(size_a) == len(size_b):
return True
# Both need to be scalars or len(1) tensors
if len(size_a) <= 1 and len(size_b) <= 1:
return True
return False
size_a, size_b = scale_a.get_size(), scale_b.get_size()
assert are_compatible_scales(size_a, size_b), (
"Expect scale_a and scale_b to be either both scalars (including single-element tensors) "
f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}."
)
# Get base template configs from superclass
for template_kwargs in super().get_template_configs(
kernel_inputs, layout, op_name
):
# Add scaled MM-specific options (moved from mm_common.scaled_mm_options)
# Override accumulator type for scaled MM
template_kwargs["ACC_TYPE"] = "tl.float32"
# Add SCALING_ROWWISE attribute based on scale_a tensor shape
template_kwargs["SCALING_ROWWISE"] = len(size_a) == 2
yield template_kwargs
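# Editor's illustration (not part of the reverted change): how the compatibility rule
# above behaves on a few assumed scale shapes, and how SCALING_ROWWISE is derived.
def _example_scale_compatibility() -> None:
    def are_compatible_scales(size_a: Any, size_b: Any) -> bool:
        # Mirrors the nested helper in ScaledMMConfigMixin.get_template_configs
        if len(size_a) == len(size_b):
            return True
        if len(size_a) <= 1 and len(size_b) <= 1:
            return True
        return False

    assert are_compatible_scales((), (1,))            # scalar + single-element tensor
    assert are_compatible_scales((128, 1), (1, 64))   # both rank-2 (row-wise scaling)
    assert not are_compatible_scales((128, 1), (64,)) # mismatched ranks are rejected
    # SCALING_ROWWISE = (len(scale_a.get_size()) == 2), i.e. True for the (128, 1) case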
# Scaled TMA-specific mixin for scaled MM templates with TMA
class ScaledTMAConfigMixin(ScaledMMConfigMixin):
"""
Scaled TMA-specific mixin that extends ScaledMMConfigMixin with TMA functionality.
This is for scaled MM templates that use device TMA.
This inherits from ScaledMMConfigMixin and adds TMA-specific options.
"""
def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]:
"""
        TMA-specific filtering, as num_warps=2 is not safe for TMA
"""
configs = [c for c in configs if c.num_warps != 2]
return super()._filter_configs(configs)
def get_template_configs(
self,
kernel_inputs: KernelInputs,
layout: Any,
op_name: str,
) -> Generator[dict[str, Any], None, None]:
"""
Generate scaled TMA template configs with both scaled MM and TMA-specific options.
"""
# Get base scaled MM template configs from superclass
for template_kwargs in super().get_template_configs(
kernel_inputs, layout, op_name
):
# Add TMA-specific options for device TMA scaled MM
template_kwargs["TMA_SIZE"] = TMA_DESCRIPTOR_SIZE
template_kwargs["NUM_SMS"] = get_num_sms()
template_kwargs["TMA_EXPERIMENTAL_API"] = not has_triton_stable_tma_api()
yield template_kwargs
# Template-specific heuristic classes using multiple inheritance
# TODO(coconutruben): replace with template.name once templates are importable
@register_template_heuristic("mm", "cuda", register=torch.version.hip is None)
# TODO(coconutruben): replace with template.name once templates are importable
@register_template_heuristic("bmm", "cuda", register=torch.version.hip is None)
class CUDAMMTemplateConfigHeuristic(MMTemplateConfigMixin, CUDAConfigHeuristic):
"""Standard MM template heuristic for CUDA"""
# TODO(coconutruben): deprecate once autoheuristic is deprecated
# TODO(coconutruben): replace with template.name once templates are importable
@register_template_heuristic("mm-ah", "cuda", register=torch.version.hip is None)
class CUDAMMAHTemplateConfigHeuristic(MMTemplateConfigMixin, CUDAConfigHeuristic):
"""Standard MM template heuristic for CUDA using the extra mm configs only (for autoheuristic)"""
def __init__(self) -> None:
super().__init__()
        # Override mm_configs (and exhaustive_configs) to use extra_mm_configs
self.mm_configs = self.extra_mm_configs
self.exhaustive_configs = self.extra_mm_configs
# TODO(coconutruben): replace with template.name once templates are importable
@register_template_heuristic(
"mm_persistent_tma", "cuda", register=torch.version.hip is None
)
class CUDAPersistentTMATemplateConfigHeuristic(TMAConfigMixin, CUDAConfigHeuristic):
"""Persistent TMA template heuristic for CUDA"""
def __init__(self) -> None:
super().__init__()
# Override mm_configs to use persistent_mm_configs
self.mm_configs = self.persistent_mm_configs
# TODO(coconutruben): replace with template.name once templates are importable
@register_template_heuristic(
"mm", "cuda", register=torch.version.hip is None, op_name="scaled_mm"
)
class CUDAScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, CUDAConfigHeuristic):
"""Scaled MM template heuristic for CUDA"""
def __init__(self) -> None:
super().__init__()
# Override mm_configs to use scaled_mm_configs
self.mm_configs = self.scaled_mm_configs
# NOTE: overriding exhaustive configs here to be the same as mm_configs
# as we haven't validated exhaustive support here yet
# TODO(coconutruben): remove this once we have validated exhaustive support
# for scaled_mm
self.exhaustive_configs = self.scaled_mm_configs
# TODO(coconutruben): replace with template.name once templates are importable
@register_template_heuristic(
"scaled_mm_device_tma", "cuda", register=torch.version.hip is None
)
class CUDAScaledTMATemplateConfigHeuristic(ScaledTMAConfigMixin, CUDAConfigHeuristic):
"""Scaled TMA template heuristic for CUDA"""
def __init__(self) -> None:
super().__init__()
# Override mm_configs to use scaled_persistent_mm_configs for TMA
self.mm_configs = self.scaled_persistent_mm_configs
# NOTE: overriding exhaustive configs here to be the same as mm_configs
# as we haven't validated exhaustive support here yet
# TODO(coconutruben): remove this once we have validated exhaustive support
# for scaled_mm
self.exhaustive_configs = self.scaled_persistent_mm_configs
# TODO(coconutruben): replace with template.name once templates are importable
@register_template_heuristic("mm_plus_mm", "cuda", register=torch.version.hip is None)
class CUDAMMPlusMMTemplateConfigHeuristic(MMTemplateConfigMixin, CUDAConfigHeuristic):
"""MM Plus MM template heuristic for CUDA"""
def __init__(self) -> None:
super().__init__()
# Override mm_configs to use mm_plus_mm_configs
self.mm_configs = self.mm_plus_mm_configs
# NOTE: overriding exhaustive configs here to be the same as mm_configs
# as we haven't validated exhaustive support here yet
# TODO(coconutruben): remove this once we have validated exhaustive support
        # for mm_plus_mm
self.exhaustive_configs = self.mm_plus_mm_configs
# TODO(coconutruben): replace with template.name once templates are importable
@register_template_heuristic(
"mm", "cuda", register=torch.version.hip is None, op_name="int_mm"
)
class CUDAInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, CUDAConfigHeuristic):
"""Int8 MM template heuristic for CUDA"""
def __init__(self) -> None:
super().__init__()
# Override mm_configs to use int8_mm_configs
self.mm_configs = self.int8_mm_configs
# NOTE: overriding exhaustive configs here to be the same as mm_configs
# as we haven't validated exhaustive support here yet
# TODO(coconutruben): remove this once we have validated exhaustive support
        # for int_mm
self.exhaustive_configs = self.int8_mm_configs
# ROCm template-specific classes
# TODO(coconutruben): replace with template.name once templates are importable
@register_template_heuristic("mm", "cuda", register=torch.version.hip is not None)
# TODO(coconutruben): replace with template.name once templates are importable
@register_template_heuristic("bmm", "cuda", register=torch.version.hip is not None)
class ROCmMMTemplateConfigHeuristic(MMTemplateConfigMixin, ROCmConfigHeuristic):
"""Standard MM template heuristic for ROCm"""
# TODO(coconutruben): deprecate once autoheuristic is deprecated
# TODO(coconutruben): replace with template.name once templates are importable
@register_template_heuristic("mm-ah", "cuda", register=torch.version.hip is not None)
class ROCmMMAHTemplateConfigHeuristic(MMTemplateConfigMixin, ROCmConfigHeuristic):
"""Standard MM template heuristic for ROCm using the extra mm configs only (for autoheuristic)"""
def __init__(self) -> None:
super().__init__()
        # Override mm_configs (and exhaustive_configs) to use extra_mm_configs
self.mm_configs = self.extra_mm_configs
self.exhaustive_configs = self.extra_mm_configs
# TODO(coconutruben): replace with template.name once templates are importable
@register_template_heuristic(
"mm", "cuda", register=torch.version.hip is not None, op_name="scaled_mm"
)
class ROCmScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, ROCmConfigHeuristic):
"""Scaled MM template heuristic for ROCm (non-TMA)"""
def __init__(self) -> None:
super().__init__()
# Override mm_configs to use scaled_mm_configs
self.mm_configs = self.scaled_mm_configs
# NOTE: overriding exhaustive configs here to be the same as mm_configs
# as we haven't validated exhaustive support here yet
# TODO(coconutruben): remove this once we have validated exhaustive support
# for scaled_mm
self.exhaustive_configs = self.scaled_mm_configs
# TODO(coconutruben): replace with template.name once templates are importable
@register_template_heuristic(
"mm", "cuda", register=torch.version.hip is not None, op_name="int_mm"
)
class ROCmInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, ROCmConfigHeuristic):
"""Int8 MM template heuristic for ROCm"""
def __init__(self) -> None:
super().__init__()
# Override mm_configs to use int8_mm_configs
self.mm_configs = self.int8_mm_configs
# NOTE: overriding exhaustive configs here to be the same as mm_configs
# as we haven't validated exhaustive support here yet
# TODO(coconutruben): remove this once we have validated exhaustive support
        # for int_mm
self.exhaustive_configs = self.int8_mm_configs
# TODO(coconutruben): replace with template.name once templates are importable
@register_template_heuristic(
"mm_plus_mm", "cuda", register=torch.version.hip is not None
)
class ROCmMMPlusMMTemplateConfigHeuristic(MMTemplateConfigMixin, ROCmConfigHeuristic):
"""MM Plus MM template heuristic for ROCm"""
def __init__(self) -> None:
super().__init__()
# Override mm_configs to use mm_plus_mm_configs
self.mm_configs = self.mm_plus_mm_configs
# NOTE: overriding exhaustive configs here to be the same as mm_configs
# as we haven't validated exhaustive support here yet
# TODO(coconutruben): remove this once we have validated exhaustive support
        # for mm_plus_mm
self.exhaustive_configs = self.mm_plus_mm_configs
# CPU template-specific classes
@register_template_heuristic("mm", "cpu")
@register_template_heuristic("bmm", "cpu")
class CPUMMTemplateConfigHeuristic(MMTemplateConfigMixin, CPUConfigHeuristic):
"""Standard MM template heuristic for CPU"""
@register_template_heuristic("mm", "cpu", op_name="scaled_mm")
class CPUScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, CPUConfigHeuristic):
"""Scaled MM template heuristic for CPU (non-TMA)"""
def __init__(self) -> None:
super().__init__()
# Override mm_configs to use scaled_mm_configs
self.mm_configs = self.scaled_mm_configs
# NOTE: overriding exhaustive configs here to be the same as mm_configs
# as we haven't validated exhaustive support here yet
# TODO(coconutruben): remove this once we have validated exhaustive support
# for scaled_mm
self.exhaustive_configs = self.scaled_mm_configs
@register_template_heuristic("mm", "cpu", op_name="int_mm")
class CPUInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, CPUConfigHeuristic):
"""Int8 MM template heuristic for CPU"""
def __init__(self) -> None:
super().__init__()
# Override mm_configs to use int8_mm_configs
self.mm_configs = self.int8_mm_configs
# NOTE: overriding exhaustive configs here to be the same as mm_configs
# as we haven't validated exhaustive support here yet
# TODO(coconutruben): remove this once we have validated exhaustive support
        # for int_mm
self.exhaustive_configs = self.int8_mm_configs
@register_template_heuristic("mm_plus_mm", "cpu")
class CPUMMPlusMMTemplateConfigHeuristic(MMTemplateConfigMixin, CPUConfigHeuristic):
"""MM Plus MM template heuristic for CPU"""
def __init__(self) -> None:
super().__init__()
# Override mm_configs to use mm_plus_mm_configs
self.mm_configs = self.mm_plus_mm_configs
# NOTE: overriding exhaustive configs here to be the same as mm_configs
# as we haven't validated exhaustive support here yet
# TODO(coconutruben): remove this once we have validated exhaustive support
        # for mm_plus_mm
self.exhaustive_configs = self.mm_plus_mm_configs
# XPU template-specific classes
@register_template_heuristic("mm", "xpu")
@register_template_heuristic("bmm", "xpu")
class XPUMMTemplateConfigHeuristic(MMTemplateConfigMixin, XPUConfigHeuristic):
"""Standard MM template heuristic for XPU"""
@register_template_heuristic("mm", "xpu", op_name="scaled_mm")
class XPUScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, XPUConfigHeuristic):
"""Scaled MM template heuristic for XPU (non-TMA)"""
def __init__(self) -> None:
super().__init__()
# Override mm_configs to use scaled_mm_configs
self.mm_configs = self.scaled_mm_configs
# NOTE: overriding exhaustive configs here to be the same as mm_configs
# as we haven't validated exhaustive support here yet
# TODO(coconutruben): remove this once we have validated exhaustive support
# for scaled_mm
self.exhaustive_configs = self.scaled_mm_configs
@register_template_heuristic("mm", "xpu", op_name="int_mm")
class XPUInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, XPUConfigHeuristic):
"""Int8 MM template heuristic for XPU"""
def __init__(self) -> None:
super().__init__()
# Override mm_configs to use int8_mm_configs
self.mm_configs = self.int8_mm_configs
# NOTE: overriding exhaustive configs here to be the same as mm_configs
# as we haven't validated exhaustive support here yet
# TODO(coconutruben): remove this once we have validated exhaustive support
        # for int_mm
self.exhaustive_configs = self.int8_mm_configs
@register_template_heuristic("mm_plus_mm", "xpu")
class XPUMMPlusMMTemplateConfigHeuristic(MMTemplateConfigMixin, XPUConfigHeuristic):
"""MM Plus MM template heuristic for XPU"""
def __init__(self) -> None:
super().__init__()
# Override mm_configs to use mm_plus_mm_configs
self.mm_configs = self.mm_plus_mm_configs
# NOTE: overriding exhaustive configs here to be the same as mm_configs
# as we haven't validated exhaustive support here yet
# TODO(coconutruben): remove this once we have validated exhaustive support
        # for mm_plus_mm
self.exhaustive_configs = self.mm_plus_mm_configs
# MTIA template-specific classes
@register_template_heuristic("mm", "mtia")
@register_template_heuristic("bmm", "mtia")
class MTIAMMTemplateConfigHeuristic(MMTemplateConfigMixin, MTIAConfigHeuristic):
"""Standard MM template heuristic for MTIA"""
@register_template_heuristic("mm", "mtia", op_name="scaled_mm")
class MTIAScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, MTIAConfigHeuristic):
"""Scaled MM template heuristic for MTIA (non-TMA)"""
def __init__(self) -> None:
super().__init__()
# Override mm_configs to use scaled_mm_configs
self.mm_configs = self.scaled_mm_configs
# NOTE: overriding exhaustive configs here to be the same as mm_configs
# as we haven't validated exhaustive support here yet
# TODO(coconutruben): remove this once we have validated exhaustive support
# for scaled_mm
self.exhaustive_configs = self.scaled_mm_configs
@register_template_heuristic("mm", "mtia", op_name="int_mm")
class MTIAInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, MTIAConfigHeuristic):
"""Int8 MM template heuristic for MTIA"""
def __init__(self) -> None:
super().__init__()
# Override mm_configs to use int8_mm_configs
self.mm_configs = self.int8_mm_configs
# NOTE: overriding exhaustive configs here to be the same as mm_configs
# as we haven't validated exhaustive support here yet
# TODO(coconutruben): remove this once we have validated exhaustive support
        # for int_mm
self.exhaustive_configs = self.int8_mm_configs
@register_template_heuristic("mm_plus_mm", "mtia")
class MTIAMMPlusMMTemplateConfigHeuristic(MMTemplateConfigMixin, MTIAConfigHeuristic):
"""MM Plus MM template heuristic for MTIA"""
def __init__(self) -> None:
super().__init__()
# Override mm_configs to use mm_plus_mm_configs
self.mm_configs = self.mm_plus_mm_configs
# NOTE: overriding exhaustive configs here to be the same as mm_configs
# as we haven't validated exhaustive support here yet
# TODO(coconutruben): remove this once we have validated exhaustive support
        # for mm_plus_mm
self.exhaustive_configs = self.mm_plus_mm_configs

View File

@ -1,98 +0,0 @@
"""
Template heuristic registry system for PyTorch Inductor.
This module provides a centralized registration system for template heuristics,
allowing automatic registration based on device type and conditional registration
for CUDA vs ROCm based on torch.version.hip.
"""
from __future__ import annotations
import logging
from functools import cache
from typing import Any, Optional, TYPE_CHECKING
if TYPE_CHECKING:
from .template_heuristics import TemplateConfigHeuristics
# Module-wide registry for template heuristics
_TEMPLATE_HEURISTIC_REGISTRY: dict[tuple[str, ...], type[TemplateConfigHeuristics]] = {}
log = logging.getLogger(__name__)
def register_template_heuristic(
template_name: str,
device_type: str,
register: bool = True,
op_name: Optional[str] = None,
) -> Any:
"""
Decorator to register template heuristic classes.
Args:
template_name: Name of the template (e.g., "mm", "bmm", "scaled_mm")
device_type: Device type ("cuda", "cpu", "xpu")
register: Whether to register this heuristic. Caller should pass the condition directly.
op_name: Name of the operator (e.g., "mm", "bmm", "scaled_mm"). This is optional
and is only used when a template uses different heuristics for different ops
Returns:
Decorator function that registers the class if conditions are met.
Example:
@register_template_heuristic("mm", "cuda", register=torch.version.hip is None)
class CUDAMMTemplateConfigHeuristic(MMTemplateConfigMixin, CUDAConfigHeuristic):
pass
"""
def decorator(
cls: type[TemplateConfigHeuristics],
) -> type[TemplateConfigHeuristics]:
if register:
key: tuple[str, ...] = (device_type, template_name)
if op_name is not None:
key = (device_type, template_name, op_name)
_TEMPLATE_HEURISTIC_REGISTRY[key] = cls
log.info(
f"Registered template heuristic: {cls.__name__} for '{template_name=}', '{device_type=}', '{op_name=}'" # noqa: G004
)
return cls
return decorator
@cache
def get_template_heuristic(
template_name: str, device_type: str, op_name: str
) -> TemplateConfigHeuristics:
"""
Retrieve a template heuristic instance for the given template and device type.
Args:
template_name: Name of the template (e.g., "mm", "bmm", "scaled_mm")
        device_type: Device type ("cuda", "cpu", "xpu")
        op_name: Name of the operator (e.g., "mm", "scaled_mm"); used for the op-specific lookup key
Returns:
Template heuristic instance.
Raises:
ValueError: If no heuristic is found for the given combination.
"""
# First check the more specific key
keys = [(device_type, template_name, op_name), (device_type, template_name)]
# Look up in registry
heuristic_class = None
for key in keys:
if key in _TEMPLATE_HEURISTIC_REGISTRY:
heuristic_class = _TEMPLATE_HEURISTIC_REGISTRY[key]
break
if heuristic_class is None:
raise ValueError(
f"No template heuristic found for '{template_name=}', "
f"'{device_type=}', '{op_name=}'. "
f"Available combinations: {list(_TEMPLATE_HEURISTIC_REGISTRY.keys())}"
)
return heuristic_class()
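# Editor's sketch (assumed usage, not part of the original file): registering a
# heuristic for a hypothetical "my_mm" template and resolving it. "_DemoHeuristic"
# and "my_mm" are illustrative names only.
def _demo_registry_usage() -> None:
    from .template_heuristics import TemplateConfigHeuristics

    @register_template_heuristic("my_mm", "cuda", register=True, op_name="scaled_mm")
    class _DemoHeuristic(TemplateConfigHeuristics):
        pass

    # Lookup prefers the (device, template, op) key and falls back to
    # (device, template); a ValueError is raised if neither is registered.
    heuristic = get_template_heuristic("my_mm", "cuda", "scaled_mm")
    assert isinstance(heuristic, _DemoHeuristic)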