Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
Revert "[inductor] consolidate common GEMM triton param retrieval (#159383)"
This reverts commit e7cc42df58.
Reverted https://github.com/pytorch/pytorch/pull/159383 on behalf of https://github.com/jataylo due to sorry but rocm CI is broken due to this PR ([comment](https://github.com/pytorch/pytorch/pull/159383#issuecomment-3145604831))
This commit is contained in:
parent c687446374
commit acad808545
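For orientation: the substantive change being reverted is how Inductor's GEMM lowerings obtain Triton template configurations. A condensed sketch of the two call patterns, lifted from the tuned_mm lowering shown in the diff below (Inductor-internal APIs, illustrative only, not standalone runnable code):

# Consolidated retrieval removed by this revert: one entry point yields
# ready-to-use template kwargs per (template, op) via template heuristics.
for kwargs in V.choices.get_mm_configs(kernel_inputs, layout, mm_template.name, "mm"):
    mm_template.maybe_append_choice(
        choices, input_nodes=kernel_inputs.nodes(), layout=layout, **kwargs
    )

# Per-op retrieval restored by this revert: fetch raw Triton configs for the
# device, then expand each config into template options via mm_options().
mm_configs = V.choices.get_base_mm_configs(device_type)
for config in mm_configs(
    m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu, dtype.itemsize)
):
    mm_template.maybe_append_choice(
        choices,
        input_nodes=(mat1, mat2),
        layout=layout,
        **mm_options(config, m, n, k, layout),
    )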
@@ -35,10 +35,7 @@ from torch._inductor.select_algorithm import (
|
|||
TritonTemplate,
|
||||
TritonTemplateCaller,
|
||||
)
|
||||
from torch._inductor.template_heuristics import (
|
||||
CUDAMMTemplateConfigHeuristic,
|
||||
GemmConfig,
|
||||
)
|
||||
from torch._inductor.template_heuristics import CUDAConfigHeuristic, GemmConfig
|
||||
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
|
|
@@ -1176,7 +1173,7 @@ class TestMaxAutotune(TestCase):
|
|||
# Force only decomposeK choice
|
||||
with (
|
||||
mock.patch(
|
||||
"torch._inductor.kernel.mm.V.choices.get_mm_configs"
|
||||
"torch._inductor.kernel.mm.V.choices.get_base_mm_configs"
|
||||
) as base_mm_mock,
|
||||
mock.patch(
|
||||
"torch._inductor.kernel.mm.use_decompose_k_choice"
|
||||
|
|
@@ -1564,9 +1561,9 @@ class TestMaxAutotune(TestCase):
|
|||
b = torch.randn(K, N, dtype=torch.float16, device="cuda", requires_grad=True)
|
||||
|
||||
with mock.patch(
|
||||
"torch._inductor.template_registry.get_template_heuristic"
|
||||
"torch._inductor.kernel.mm.V.choices.get_config_heuristics"
|
||||
) as config_mock:
|
||||
config_heuristics = CUDAMMTemplateConfigHeuristic()
|
||||
config_heuristics = CUDAConfigHeuristic()
|
||||
|
||||
# Traditionally, this would be set of all possible configs
|
||||
# We mock out the code path for the sake of the unit test
@@ -9,7 +9,6 @@ import torch
|
|||
|
||||
from . import config
|
||||
from .codecache import write_text
|
||||
from .kernel_inputs import KernelInputs # noqa: TC001
|
||||
from .metrics import get_metric_table, is_metric_table_enabled
|
||||
from .runtime.hints import DeviceProperties, ReductionHint
|
||||
from .scheduler import BaseSchedulerNode, Scheduler, WhyNoFuse
|
||||
|
|
@@ -21,7 +20,6 @@ from .template_heuristics import (
|
|||
ROCmConfigHeuristic,
|
||||
XPUConfigHeuristic,
|
||||
)
|
||||
from .template_registry import get_template_heuristic
|
||||
from .virtualized import V
|
||||
|
||||
|
||||
|
|
@@ -73,6 +71,58 @@ class InductorChoices:
|
|||
else:
|
||||
return BaseConfigHeuristic()
|
||||
|
||||
# GEMM configs
|
||||
def get_base_mm_configs(
|
||||
self, device_type: Optional[str] = "cuda"
|
||||
) -> partial[Generator[TritonConfig, None, None]]:
|
||||
mm_heuristics = self.get_config_heuristics(device_type)
|
||||
if config.max_autotune_gemm_search_space != "EXHAUSTIVE":
|
||||
return mm_heuristics.get_mm_configs()
|
||||
else:
|
||||
return mm_heuristics.get_exhaustive_mm_configs()
|
||||
|
||||
def get_extra_mm_configs(
|
||||
self, device_type: Optional[str] = "cuda"
|
||||
) -> partial[Generator[TritonConfig, None, None]]:
|
||||
mm_heuristics = self.get_config_heuristics(device_type)
|
||||
return mm_heuristics.get_extra_mm_configs()
|
||||
|
||||
def get_int8_mm_configs(
|
||||
self, device_type: Optional[str] = "cuda"
|
||||
) -> partial[Generator[TritonConfig, None, None]]:
|
||||
mm_heuristics = self.get_config_heuristics(device_type)
|
||||
return mm_heuristics.get_int8_mm_configs()
|
||||
|
||||
def get_mixed_mm_configs(
|
||||
self, device_type: Optional[str] = "cuda"
|
||||
) -> partial[Generator[TritonConfig, None, None]]:
|
||||
mm_heuristics = self.get_config_heuristics(device_type)
|
||||
return mm_heuristics.get_mixed_mm_configs()
|
||||
|
||||
def get_persistent_mm_configs(
|
||||
self, device_type: Optional[str] = "cuda"
|
||||
) -> partial[Generator[TritonConfig, None, None]]:
|
||||
mm_heuristics = self.get_config_heuristics(device_type)
|
||||
return mm_heuristics.get_persistent_mm_configs()
|
||||
|
||||
def get_scaled_mm_configs(
|
||||
self, device_type: Optional[str] = "cuda"
|
||||
) -> partial[Generator[TritonConfig, None, None]]:
|
||||
mm_heuristics = self.get_config_heuristics(device_type)
|
||||
return mm_heuristics.get_scaled_mm_configs()
|
||||
|
||||
def get_scaled_persistent_mm_configs(
|
||||
self, device_type: Optional[str] = "cuda"
|
||||
) -> partial[Generator[TritonConfig, None, None]]:
|
||||
mm_heuristics = self.get_config_heuristics(device_type)
|
||||
return mm_heuristics.get_scaled_persistent_mm_configs()
|
||||
|
||||
def get_mm_plus_mm_configs(
|
||||
self, device_type: Optional[str] = "cuda"
|
||||
) -> partial[Generator[TritonConfig, None, None]]:
|
||||
mm_heuristics = self.get_config_heuristics(device_type)
|
||||
return mm_heuristics.get_mm_plus_mm_configs()
|
||||
|
||||
# Conv configs
|
||||
def get_conv_configs(
|
||||
self, device_type: Optional[str] = "cuda"
|
||||
|
|
@@ -81,7 +131,6 @@ class InductorChoices:
|
|||
return conv_heuristics.get_conv_configs()
|
||||
|
||||
# Flex attention configs
|
||||
# TODO(coconutruben): break out flexattention/decode configs into the new retrieval mechanism
|
||||
def get_flex_attention_fwd_configs(
|
||||
self, head_dim: int, dtype: torch.dtype, device_type: Optional[str] = "cuda"
|
||||
) -> list[Any]:
|
||||
|
|
@@ -100,37 +149,6 @@ class InductorChoices:
|
|||
flex_heuristics = self.get_config_heuristics(device_type)
|
||||
return flex_heuristics.get_flex_decode_configs(head_dim, dtype)
|
||||
|
||||
def get_mm_configs(
|
||||
self,
|
||||
kernel_inputs: KernelInputs,
|
||||
layout: Any,
|
||||
template_name: str,
|
||||
op_name: str,
|
||||
) -> Generator[dict[str, Any], None, None]:
|
||||
"""
|
||||
Get generator of template parameters for MM templates using template-specific heuristics.
|
||||
|
||||
Args:
|
||||
kernel_inputs: MMKernelInputs containing input tensor nodes and matrix indices
|
||||
layout: Output layout
|
||||
template_name: Template name (e.g., "bmm", "mm", "mm_persistent_tma")
|
||||
op_name: Operation name (e.g., "bmm", "baddbmm", "addmm", "mm_plus_mm")
|
||||
|
||||
Yields:
|
||||
Template parameter dictionaries ready for maybe_append_choice
|
||||
"""
|
||||
input_tensors = kernel_inputs.nodes()
|
||||
if len(input_tensors) < 2:
|
||||
raise ValueError(f"Need at least 2 input tensors, got {len(input_tensors)}")
|
||||
|
||||
# Extract device_type from kernel_inputs
|
||||
device_type = kernel_inputs.device_type
|
||||
assert device_type is not None, "get_mm_configs requires a valid device type"
|
||||
# Get the appropriate template-specific heuristic
|
||||
heuristic = get_template_heuristic(template_name, device_type, op_name)
|
||||
|
||||
yield from heuristic.get_template_configs(kernel_inputs, layout, op_name)
|
||||
|
||||
def triton_kernel_kwargs(
|
||||
self,
|
||||
kernel_cls: type[TritonKernel],
@@ -6,7 +6,6 @@ from torch._dynamo.utils import counters
|
|||
from torch._inductor.codegen.rocm.ck_universal_gemm_template import CKGemmTemplate
|
||||
|
||||
from .. import ir, lowering as L
|
||||
from ..kernel_inputs import MMKernelInputs
|
||||
from ..select_algorithm import (
|
||||
autotune_select_algorithm,
|
||||
ExternKernelChoice,
|
||||
|
|
@@ -27,6 +26,8 @@ from .mm_common import (
|
|||
addmm_epilogue,
|
||||
is_batch_stride_largest,
|
||||
mm_args,
|
||||
mm_config_kwargs,
|
||||
mm_options,
|
||||
)
|
||||
|
||||
|
||||
|
|
@@ -39,6 +40,13 @@ def bmm_grid(b, m, n, meta, *, cdiv):
|
|||
return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), b, 1)
|
||||
|
||||
|
||||
def _is_large_block_for_cpu(m, n, k):
|
||||
# Thresholds are experimentally determined to reduce Triton CPU compile times
|
||||
if m > 128 or n > 128 or k > 128:
|
||||
return True
|
||||
return m * n > 2**12
|
||||
|
||||
|
||||
bmm_template = TritonTemplate(
|
||||
name="bmm",
|
||||
grid=bmm_grid,
|
||||
|
|
@@ -167,14 +175,9 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
|
|||
meta_mat2 = V.graph.current_node.args[1]
|
||||
mat2 = may_require_contiguous(mat2, meta_mat2)
|
||||
|
||||
# TODO(coconutruben): integrate into MMKernelInputs when all callsites use that
|
||||
m, n, k, layout, mat1, mat2 = mm_args(
|
||||
mat1, mat2, layout=layout, out_dtype=out_dtype
|
||||
)
|
||||
name = "bmm"
|
||||
|
||||
# Create MMKernelInputs for BMM at the top
|
||||
kernel_inputs = MMKernelInputs([mat1, mat2])
|
||||
|
||||
# below is for getting an overview logging info of inductor mms
|
||||
batch_size = mat1.get_size()[0] # Extract batch dimension
|
||||
|
|
@@ -192,27 +195,31 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
|
|||
|
||||
if out_dtype:
|
||||
assert mat1.get_device().type == "cuda", "out_dtype is only supported for CUDA"
|
||||
aten_func = aten_bmm_dtype.bind(
|
||||
kernel_inputs.nodes(), layout, out_dtype=out_dtype
|
||||
)
|
||||
aten_func = aten_bmm_dtype.bind((mat1, mat2), layout, out_dtype=out_dtype)
|
||||
else:
|
||||
aten_func = aten_bmm.bind(kernel_inputs.nodes(), layout)
|
||||
aten_func = aten_bmm.bind((mat1, mat2), layout)
|
||||
|
||||
# options to tune from
|
||||
choices = [aten_func] if use_aten_gemm_kernels() else []
|
||||
|
||||
device_type = ir.get_device_type(mat1)
|
||||
bmm_configs = V.choices.get_base_mm_configs(device_type)
|
||||
|
||||
dtype = mat1.get_dtype()
|
||||
if use_triton_template(layout):
|
||||
# TODO: add out_dtype support for Triton Template
|
||||
assert out_dtype is None, "out_dtype is not supported for Triton"
|
||||
|
||||
for kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs, layout, bmm_template.name, name
|
||||
for config in bmm_configs(
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
**mm_config_kwargs(device_type, _is_large_block_for_cpu, dtype.itemsize),
|
||||
):
|
||||
bmm_template.maybe_append_choice(
|
||||
choices,
|
||||
input_nodes=kernel_inputs.nodes(),
|
||||
input_nodes=(mat1, mat2),
|
||||
layout=layout,
|
||||
**kwargs,
|
||||
**mm_options(config, m, n, k, layout),
|
||||
)
|
||||
_, is_nonzero = _is_static_problem(layout)
|
||||
batch_stride_largest = is_batch_stride_largest(mat1, mat2, layout)
|
||||
|
|
@@ -220,13 +227,11 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
|
|||
batch_stride_largest
|
||||
and is_nonzero
|
||||
and use_cutlass_template(layout, m, n, k)
|
||||
and _use_cutlass_for_op(name)
|
||||
and _use_cutlass_for_op("bmm")
|
||||
):
|
||||
from ..codegen.cuda.gemm_template import CUTLASS3xGemmTemplate
|
||||
|
||||
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
|
||||
choices, layout, kernel_inputs.nodes()
|
||||
) # type: ignore[arg-type]
|
||||
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2]) # type: ignore[arg-type]
|
||||
|
||||
if use_cpp_bmm_template(layout, mat1, mat2):
|
||||
from ..codegen.cpp_bmm_template import CppBmmTemplate
|
||||
|
|
@@ -234,23 +239,19 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
|
|||
CppBmmTemplate.add_choices(
|
||||
choices,
|
||||
layout,
|
||||
kernel_inputs.nodes(),
|
||||
[mat1, mat2],
|
||||
)
|
||||
|
||||
if use_ck_gemm_template(layout, m, n, k):
|
||||
CKGemmTemplate.add_ck_gemm_choices(choices, layout, kernel_inputs.nodes())
|
||||
CKGemmTemplate.add_ck_gemm_choices(choices, layout, [mat1, mat2])
|
||||
|
||||
return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout)
|
||||
return autotune_select_algorithm("bmm", choices, [mat1, mat2], layout)
|
||||
|
||||
|
||||
@L.register_lowering(aten.baddbmm)
|
||||
def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
|
||||
# TODO(coconutruben): integrate into MMKernelInputs when all callsites use that
|
||||
m, n, k, layout, mat1, mat2, inp = mm_args(mat1, mat2, inp, layout=layout)
|
||||
|
||||
# Create MMKernelInputs for BadDBMM at the top
|
||||
kernel_inputs = MMKernelInputs([inp, mat1, mat2])
|
||||
|
||||
# below is for getting an overview logging info of inductor mms
|
||||
batch_size = mat1.get_size()[0]
|
||||
counters["aten_mm_info"][f"aten.baddbmm_{batch_size}_{m}_{n}_{k}"] += 1
|
||||
|
|
@@ -265,26 +266,29 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
|
|||
inp.get_dtype(),
|
||||
layout,
|
||||
)
|
||||
name = "baddbmm"
|
||||
|
||||
# options to tune from
|
||||
choices = (
|
||||
[aten_baddbmm.bind(kernel_inputs.nodes(), layout, alpha=alpha, beta=beta)]
|
||||
[aten_baddbmm.bind((inp, mat1, mat2), layout, alpha=alpha, beta=beta)]
|
||||
if use_aten_gemm_kernels()
|
||||
else []
|
||||
)
|
||||
|
||||
device_type = ir.get_device_type(mat1)
|
||||
bmm_configs = V.choices.get_base_mm_configs(device_type)
|
||||
|
||||
if use_triton_template(layout):
|
||||
for kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs, layout, bmm_template.name, name
|
||||
for config in bmm_configs(
|
||||
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
|
||||
):
|
||||
bmm_template.maybe_append_choice(
|
||||
choices,
|
||||
input_nodes=kernel_inputs.nodes(),
|
||||
input_nodes=(inp, mat1, mat2),
|
||||
layout=layout,
|
||||
**kwargs,
|
||||
**mm_options(config, m, n, k, layout),
|
||||
prefix_args=1,
|
||||
epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
|
||||
epilogue_fn_hash=str(["addmm_epilogue", layout.dtype, alpha, beta]),
|
||||
)
|
||||
|
||||
return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout)
|
||||
return autotune_select_algorithm("baddbmm", choices, [inp, mat1, mat2], layout)
@@ -29,6 +29,7 @@ from ..utils import (
|
|||
use_triton_template,
|
||||
)
|
||||
from ..virtualized import V
|
||||
from .mm_common import mm_config_kwargs
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@@ -60,6 +61,13 @@ def conv3d_grid(n, c, d, h, w, meta, *, cdiv):
|
|||
)
|
||||
|
||||
|
||||
def _is_large_block_for_cpu(m, n, k):
|
||||
# Thresholds are experimentally determined to reduce Triton CPU compile times
|
||||
if m > 256 or n > 256 or k > 256:
|
||||
return True
|
||||
return m * n * k > 2**17
|
||||
|
||||
|
||||
LOOP_BODY_2D = """
|
||||
idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H
|
||||
idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W
|
||||
|
|
@@ -595,6 +603,7 @@ def convolution(
|
|||
sympy_product([x.get_size()[0], *x.get_size()[2:]]),
|
||||
out_chan,
|
||||
in_chan,
|
||||
**mm_config_kwargs(device_type, _is_large_block_for_cpu),
|
||||
):
|
||||
if ndim == 2:
|
||||
conv2d_template.maybe_append_choice(
@@ -19,13 +19,12 @@ from torch._inductor.virtualized import V
|
|||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.torch_version import TorchVersion
|
||||
|
||||
from .. import config as inductor_config
|
||||
from .. import config as inductor_config, ir
|
||||
from ..codegen.cuda.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate
|
||||
from ..codegen.rocm.ck_tile_universal_gemm_template import CKTileGemmTemplate
|
||||
from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate
|
||||
from ..codegen.subgraph import SubgraphTemplate
|
||||
from ..ir import FlexibleLayout, is_triton
|
||||
from ..kernel_inputs import MMKernelInputs
|
||||
from ..lowering import (
|
||||
add_layout_constraint,
|
||||
constrain_to_fx_strides,
|
||||
|
|
@@ -55,9 +54,13 @@ from .mm_common import (
|
|||
_is_static_problem,
|
||||
addmm_epilogue,
|
||||
mm_args,
|
||||
mm_config_kwargs,
|
||||
mm_grid,
|
||||
mm_options,
|
||||
persistent_mm_grid,
|
||||
persistent_mm_options,
|
||||
scale_mm_epilogue,
|
||||
scaled_mm_options,
|
||||
)
|
||||
|
||||
|
||||
|
|
@@ -584,6 +587,11 @@ def _is_int8_mat(mat):
|
|||
return mat.get_dtype() in (torch.int8, torch.uint8)
|
||||
|
||||
|
||||
def _is_large_block_for_cpu(m, n, k):
|
||||
# Thresholds are experimentally determined to reduce Triton CPU compile times
|
||||
return m * n > 2**13
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def using_b200() -> bool:
|
||||
"""Returns true if the device is a NVIDIA B200, otherwise returns false."""
|
||||
|
|
@@ -653,14 +661,10 @@ def tuned_mm(mat1, mat2, *, layout=None):
|
|||
"""
|
||||
Lowering for autotuning aten.mm with different backends (Aten, Triton, CUTLASS, etc.)
|
||||
"""
|
||||
# TODO(coconutruben): integrate into MMKernelInputs when all callsites use that
|
||||
m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
|
||||
static_shape, is_nonzero = _is_static_problem(layout)
|
||||
device_type = ir.get_device_type(mat1)
|
||||
name = "mm"
|
||||
|
||||
# Create MMKernelInputs for standard MM at the top
|
||||
kernel_inputs = MMKernelInputs([mat1, mat2])
|
||||
|
||||
# below is for getting an overview logging info of inductor mms
|
||||
counters["aten_mm_info"][f"aten.mm_{m}_{n}_{k}"] += 1
|
||||
log.info(
|
||||
|
|
@@ -681,38 +685,48 @@ def tuned_mm(mat1, mat2, *, layout=None):
|
|||
|
||||
# options to tune from
|
||||
choices = (
|
||||
[aten_mm.bind(kernel_inputs.nodes(), aten_layout)]
|
||||
if use_aten_gemm_kernels()
|
||||
else []
|
||||
[aten_mm.bind((mat1, mat2), aten_layout)] if use_aten_gemm_kernels() else []
|
||||
)
|
||||
static_shape, is_nonzero = _is_static_problem(layout)
|
||||
|
||||
mm_configs = V.choices.get_base_mm_configs(device_type)
|
||||
persistent_mm_configs = V.choices.get_persistent_mm_configs(device_type)
|
||||
extra_mm_configs = V.choices.get_extra_mm_configs(device_type)
|
||||
|
||||
dtype = mat1.get_dtype()
|
||||
if is_nonzero and use_triton_template(layout):
|
||||
# Get template params using the new unified function
|
||||
for kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs, layout, mm_template.name, "mm"
|
||||
for config in mm_configs(
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
**mm_config_kwargs(device_type, _is_large_block_for_cpu, dtype.itemsize),
|
||||
):
|
||||
mm_template.maybe_append_choice(
|
||||
choices,
|
||||
input_nodes=kernel_inputs.nodes(),
|
||||
input_nodes=(mat1, mat2),
|
||||
layout=layout,
|
||||
**kwargs,
|
||||
**mm_options(config, m, n, k, layout),
|
||||
)
|
||||
|
||||
if use_triton_tma_template(mat1, mat2):
|
||||
# Get TMA template params using the new unified function
|
||||
for kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs, layout, persistent_tma_mm_template.name, "mm"
|
||||
for config in persistent_mm_configs(
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
**mm_config_kwargs(
|
||||
device_type, _is_large_block_for_cpu, dtype.itemsize
|
||||
),
|
||||
):
|
||||
persistent_tma_mm_template.maybe_append_choice(
|
||||
choices,
|
||||
input_nodes=kernel_inputs.nodes(),
|
||||
input_nodes=(mat1, mat2),
|
||||
layout=layout,
|
||||
workspace_arg=get_tma_workspace_arg(
|
||||
num_tma_descriptors=2,
|
||||
device=mat1.get_device(),
|
||||
),
|
||||
**kwargs,
|
||||
**mm_options(config, m, n, k, layout),
|
||||
**persistent_mm_options(mat1, mat2),
|
||||
)
|
||||
|
||||
from torch._inductor.ir import get_free_symbols
|
||||
|
|
@@ -762,20 +776,18 @@ def tuned_mm(mat1, mat2, *, layout=None):
|
|||
and use_cutlass_template(layout, m, n, k)
|
||||
and _use_cutlass_for_op("mm")
|
||||
):
|
||||
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
|
||||
choices, layout, kernel_inputs.nodes()
|
||||
)
|
||||
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2])
|
||||
|
||||
if is_nonzero and use_ck_gemm_template(layout, m, n, k):
|
||||
CKGemmTemplate.add_ck_gemm_choices(choices, layout, kernel_inputs.nodes())
|
||||
CKGemmTemplate.add_ck_gemm_choices(choices, layout, [mat1, mat2])
|
||||
if is_nonzero and use_ck_tile_gemm_template(layout, m, n, k):
|
||||
CKTileGemmTemplate.add_choices(choices, layout, kernel_inputs.nodes())
|
||||
CKTileGemmTemplate.add_choices(choices, layout, [mat1, mat2])
|
||||
|
||||
if use_cpp_gemm_template(layout, mat1, mat2):
|
||||
CppGemmTemplate.add_choices(
|
||||
choices,
|
||||
layout,
|
||||
kernel_inputs.nodes(),
|
||||
[mat1, mat2],
|
||||
)
|
||||
|
||||
input_nodes = [mat1, mat2]
|
||||
|
|
@@ -789,20 +801,14 @@ def tuned_mm(mat1, mat2, *, layout=None):
|
|||
if use_aten_gemm_kernels():
|
||||
always_included.append("extern_mm")
|
||||
num_choices_before_extra_configs = len(choices)
|
||||
for kwargs in V.choices.get_mm_configs(
|
||||
# TODO(coconutruben): remove once we deprecate ah
|
||||
# mm-extra is a hack to keep the ah functionality alive
|
||||
# while we transition to the unified kwargs retrieval
|
||||
kernel_inputs,
|
||||
layout,
|
||||
"mm-ah",
|
||||
"mm",
|
||||
for config in extra_mm_configs(
|
||||
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
|
||||
):
|
||||
mm_template.maybe_append_choice(
|
||||
choices,
|
||||
input_nodes=kernel_inputs.nodes(),
|
||||
input_nodes=(mat1, mat2),
|
||||
layout=layout,
|
||||
**kwargs,
|
||||
**mm_options(config, m, n, k, layout),
|
||||
)
|
||||
|
||||
# using AutoHeuristic for ranking
|
||||
|
|
@@ -832,16 +838,13 @@ def tuned_mm(mat1, mat2, *, layout=None):
|
|||
choices = choices[:num_choices_before_extra_configs]
|
||||
|
||||
for k in inductor_config.external_matmul:
|
||||
choices.append(
|
||||
lazy_register_extern_choice(k).bind(kernel_inputs.nodes(), layout)
|
||||
)
|
||||
choices.append(lazy_register_extern_choice(k).bind((mat1, mat2), layout))
|
||||
|
||||
return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout)
|
||||
return autotune_select_algorithm(name, choices, [mat1, mat2], layout)
|
||||
|
||||
|
||||
@register_lowering(aten._int_mm, type_promotion_kind=None)
|
||||
def tuned_int_mm(mat1, mat2, *, layout=None):
|
||||
# TODO(coconutruben): integrate into MMKernelInputs when all callsites use that
|
||||
m, n, k, layout, mat1, mat2 = mm_args(
|
||||
mat1, mat2, layout=layout, out_dtype=torch.int32
|
||||
)
|
||||
|
|
@@ -858,6 +861,8 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
|
|||
layout,
|
||||
)
|
||||
|
||||
device_type = ir.get_device_type(mat1)
|
||||
|
||||
static_shape, is_nonzero = _is_static_problem(layout)
|
||||
use_cutlass = static_shape and is_nonzero and use_cutlass_template(layout, m, n, k)
|
||||
|
||||
|
|
@@ -865,37 +870,33 @@
|
|||
[aten__int_mm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else []
|
||||
)
|
||||
|
||||
# Create MMKernelInputs for Int MM
|
||||
kernel_inputs = MMKernelInputs([mat1, mat2])
|
||||
|
||||
if use_cutlass and _use_cutlass_for_op("int_mm"):
|
||||
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
|
||||
choices, layout, kernel_inputs.nodes(), fuseable=True, non_fuseable=True
|
||||
choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True
|
||||
)
|
||||
|
||||
int8_mm_configs = V.choices.get_int8_mm_configs(device_type)
|
||||
|
||||
if is_nonzero and use_triton_template(layout, enable_int32=True):
|
||||
for kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs, layout, mm_template.name, "int_mm"
|
||||
for config in int8_mm_configs(
|
||||
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
|
||||
):
|
||||
mm_template.maybe_append_choice(
|
||||
choices,
|
||||
input_nodes=kernel_inputs.nodes(),
|
||||
input_nodes=(mat1, mat2),
|
||||
layout=layout,
|
||||
**kwargs,
|
||||
**mm_options(config, m, n, k, layout),
|
||||
)
|
||||
|
||||
return autotune_select_algorithm("int_mm", choices, kernel_inputs.nodes(), layout)
|
||||
return autotune_select_algorithm("int_mm", choices, [mat1, mat2], layout)
|
||||
|
||||
|
||||
@register_lowering(aten.addmm, type_promotion_kind=None)
|
||||
def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
|
||||
# TODO(coconutruben): integrate into MMKernelInputs when all callsites use that
|
||||
device_type = ir.get_device_type(mat1)
|
||||
m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout)
|
||||
static_shape, is_nonzero = _is_static_problem(layout)
|
||||
|
||||
# Create MMKernelInputs for AddMM at the top
|
||||
kernel_inputs = MMKernelInputs([inp_expanded, mat1, mat2])
|
||||
|
||||
# below is for getting an overview logging info of inductor mms
|
||||
counters["aten_mm_info"][f"aten.addmm_{m}_{n}_{k}"] += 1
|
||||
log.info(
|
||||
|
|
@@ -922,9 +923,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
|
|||
choices = (
|
||||
[
|
||||
aten_addmm.bind(
|
||||
# TODO(coconutruben): replace with kernel_inputs.nodes()
|
||||
# once that supports the unexpanded nodes as well
|
||||
[inp, mat1, mat2],
|
||||
(inp, mat1, mat2),
|
||||
layout,
|
||||
alpha=alpha,
|
||||
beta=beta,
|
||||
|
|
@@ -933,19 +932,12 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
|
|||
if use_aten_gemm_kernels()
|
||||
else []
|
||||
)
|
||||
return autotune_select_algorithm(
|
||||
# TODO(coconutruben): replace with kernel_inputs.nodes()
|
||||
# once that supports the unexpanded nodes as well
|
||||
"addmm",
|
||||
choices,
|
||||
[inp, mat1, mat2],
|
||||
layout,
|
||||
)
|
||||
return autotune_select_algorithm("addmm", choices, [inp, mat1, mat2], layout)
|
||||
|
||||
choices = (
|
||||
[
|
||||
aten_addmm.bind(
|
||||
kernel_inputs.nodes(),
|
||||
(inp_expanded, mat1, mat2),
|
||||
layout,
|
||||
alpha=alpha,
|
||||
beta=beta,
|
||||
|
|
@@ -965,42 +957,50 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
|
|||
choices.insert(
|
||||
0,
|
||||
aten_bias_addmm.bind(
|
||||
kernel_inputs.nodes(),
|
||||
layout,
|
||||
alpha=alpha,
|
||||
beta=beta,
|
||||
(inp_expanded, mat1, mat2), layout, alpha=alpha, beta=beta
|
||||
),
|
||||
)
|
||||
|
||||
mm_configs = V.choices.get_base_mm_configs(device_type)
|
||||
persistent_mm_configs = V.choices.get_persistent_mm_configs(device_type)
|
||||
|
||||
dtype = mat1.get_dtype()
|
||||
if is_nonzero and use_triton_template(layout):
|
||||
# Get template params using the new unified function
|
||||
for kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs, layout, mm_template.name, "addmm"
|
||||
for config in mm_configs(
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
**mm_config_kwargs(device_type, _is_large_block_for_cpu, dtype.itemsize),
|
||||
):
|
||||
mm_template.maybe_append_choice(
|
||||
choices,
|
||||
input_nodes=kernel_inputs.nodes(),
|
||||
input_nodes=(inp_expanded, mat1, mat2),
|
||||
layout=layout,
|
||||
**kwargs,
|
||||
**mm_options(config, m, n, k, layout),
|
||||
prefix_args=1,
|
||||
epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
|
||||
epilogue_fn_hash=str(["addmm_epilogue", layout.dtype, alpha, beta]),
|
||||
)
|
||||
|
||||
if use_triton_tma_template(mat1, mat2):
|
||||
# Get TMA template params using the new unified function
|
||||
for kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs, layout, persistent_tma_mm_template.name, "addmm"
|
||||
for config in persistent_mm_configs(
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
**mm_config_kwargs(
|
||||
device_type, _is_large_block_for_cpu, dtype.itemsize
|
||||
),
|
||||
):
|
||||
persistent_tma_mm_template.maybe_append_choice(
|
||||
choices,
|
||||
input_nodes=kernel_inputs.nodes(),
|
||||
input_nodes=(inp_expanded, mat1, mat2),
|
||||
layout=layout,
|
||||
workspace_arg=get_tma_workspace_arg(
|
||||
num_tma_descriptors=2,
|
||||
device=mat1.get_device(),
|
||||
),
|
||||
**kwargs,
|
||||
**mm_options(config, m, n, k, layout),
|
||||
**persistent_mm_options(mat1, mat2),
|
||||
prefix_args=1,
|
||||
epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
|
||||
)
|
||||
|
|
@@ -1013,20 +1013,17 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
|
|||
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
|
||||
choices,
|
||||
layout,
|
||||
# reorder here because CUTLASS expects (x, w, bias) but torch
|
||||
# is bias, x, w
|
||||
kernel_inputs.nodes(reorder=[1, 2, 0]),
|
||||
[mat1, mat2, inp_expanded],
|
||||
alpha=alpha,
|
||||
beta=beta,
|
||||
input_reorder=[2, 0, 1],
|
||||
)
|
||||
|
||||
if is_nonzero and use_ck_gemm_template(layout, m, n, k):
|
||||
CKGemmTemplate.add_ck_gemm_choices(
|
||||
choices,
|
||||
layout,
|
||||
# reorder here because CK expects (x, w, bias) but torch
|
||||
# is bias, x, w
|
||||
kernel_inputs.nodes(reorder=[1, 2, 0]),
|
||||
[mat1, mat2, inp_expanded],
|
||||
alpha=alpha,
|
||||
beta=beta,
|
||||
input_reorder=[2, 0, 1],
|
||||
|
|
@@ -1036,13 +1033,15 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
|
|||
CppGemmTemplate.add_choices(
|
||||
choices,
|
||||
layout,
|
||||
kernel_inputs.nodes(),
|
||||
[inp_expanded, mat1, mat2],
|
||||
alpha=alpha,
|
||||
beta=beta,
|
||||
has_bias=True,
|
||||
)
|
||||
|
||||
return autotune_select_algorithm("addmm", choices, kernel_inputs.nodes(), layout)
|
||||
return autotune_select_algorithm(
|
||||
"addmm", choices, [inp_expanded, mat1, mat2], layout
|
||||
)
|
||||
|
||||
|
||||
@register_lowering(aten._sparse_semi_structured_mm, type_promotion_kind=None)
|
||||
|
|
@@ -1090,7 +1089,7 @@ def tuned_sparse_semi_structured_mm(
|
|||
)
|
||||
|
||||
return autotune_select_algorithm(
|
||||
"sparse_semi_structured_mm", choices, (mat1, mat1_meta, mat2), layout
|
||||
"sparse_semi_structured_mm", choices, [mat1, mat1_meta, mat2], layout
|
||||
)
|
||||
|
||||
|
||||
|
|
@@ -1124,7 +1123,6 @@ def tuned_scaled_mm(
|
|||
Returns:
|
||||
Tensor: The result of the scaled matrix multiplication
|
||||
"""
|
||||
# TODO(coconutruben): integrate into MMKernelInputs when all callsites use that
|
||||
m, n, k, layout, mat_a, mat_b = mm_args(
|
||||
mat_a, mat_b, layout=layout, out_dtype=out_dtype
|
||||
)
|
||||
|
|
@@ -1140,6 +1138,7 @@ def tuned_scaled_mm(
|
|||
layout,
|
||||
)
|
||||
|
||||
device_type = ir.get_device_type(mat_a)
|
||||
check_supported_striding(mat_a, mat_b)
|
||||
|
||||
scale_a_real, scale_b_real = realize_inputs(scale_a, scale_b)
|
||||
|
|
@@ -1166,51 +1165,59 @@
|
|||
|
||||
_, is_nonzero = _is_static_problem(layout)
|
||||
|
||||
# Prepare triton input nodes and create kernel_inputs at the top
|
||||
triton_input_nodes: list[Any]
|
||||
if bias and len(mat_b.get_size()) == len(bias.get_size()) + 1:
|
||||
# Need to unsqueeze bias from [N] -> [1, N]
|
||||
triton_bias = L[aten.unsqueeze](bias, 0)
|
||||
else:
|
||||
triton_bias = bias
|
||||
|
||||
if len(scale_a.get_size()) == 0 or len(scale_b.get_size()) == 0:
|
||||
assert len(scale_a.get_size()) == len(scale_b.get_size())
|
||||
# Need to unsqueeze scale from [] -> [1, 1]
|
||||
triton_scale_a = L[aten.unsqueeze](L[aten.unsqueeze](scale_a, 0), 1)
|
||||
triton_scale_b = L[aten.unsqueeze](L[aten.unsqueeze](scale_b, 0), 1)
|
||||
else:
|
||||
triton_scale_a = scale_a
|
||||
triton_scale_b = scale_b
|
||||
|
||||
if bias:
|
||||
triton_input_nodes = [
|
||||
mat_a,
|
||||
mat_b,
|
||||
triton_scale_a,
|
||||
triton_scale_b,
|
||||
triton_bias,
|
||||
]
|
||||
suffix_args = 3
|
||||
else:
|
||||
triton_input_nodes = [mat_a, mat_b, triton_scale_a, triton_scale_b]
|
||||
suffix_args = 2
|
||||
|
||||
# Create MMKernelInputs for Scaled MM (matrices are at indices 0, 1)
|
||||
kernel_inputs = MMKernelInputs(triton_input_nodes, mat1_idx=0, mat2_idx=1)
|
||||
scaled_mm_configs = V.choices.get_scaled_mm_configs(device_type)
|
||||
scaled_persistent_mm_configs = V.choices.get_scaled_persistent_mm_configs(
|
||||
device_type
|
||||
)
|
||||
|
||||
if is_nonzero and use_triton_template(layout, enable_float8=True):
|
||||
triton_input_nodes: tuple[Any, ...]
|
||||
if bias and len(mat_b.get_size()) == len(bias.get_size()) + 1:
|
||||
# Need to unsqueeze bias from [N] -> [1, N]
|
||||
triton_bias = L[aten.unsqueeze](bias, 0)
|
||||
else:
|
||||
triton_bias = bias
|
||||
|
||||
if len(scale_a.get_size()) == 0 or len(scale_b.get_size()) == 0:
|
||||
assert len(scale_a.get_size()) == len(scale_b.get_size())
|
||||
# Need to unsqueeze scale from [] -> [1, 1]
|
||||
triton_scale_a = L[aten.unsqueeze](L[aten.unsqueeze](scale_a, 0), 1)
|
||||
triton_scale_b = L[aten.unsqueeze](L[aten.unsqueeze](scale_b, 0), 1)
|
||||
else:
|
||||
triton_scale_a = scale_a
|
||||
triton_scale_b = scale_b
|
||||
|
||||
if bias:
|
||||
triton_input_nodes = (
|
||||
mat_a,
|
||||
mat_b,
|
||||
triton_scale_a,
|
||||
triton_scale_b,
|
||||
triton_bias,
|
||||
)
|
||||
suffix_args = 3
|
||||
else:
|
||||
triton_input_nodes = (mat_a, mat_b, triton_scale_a, triton_scale_b)
|
||||
suffix_args = 2
|
||||
|
||||
# TODO (paulzhan): There is no template that exists for bias and TMA
|
||||
# Don't run tma template currently if bias exists
|
||||
if use_triton_tma_template(mat_a, mat_b) and not bias:
|
||||
# Get TMA template params using the new unified function
|
||||
for kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs, layout, scaled_mm_device_tma_template.name, "scaled_mm"
|
||||
):
|
||||
kwargs["USE_FAST_ACCUM"] = use_fast_accum
|
||||
for config in scaled_persistent_mm_configs(m, n, k):
|
||||
kwargs = scaled_mm_options(
|
||||
config,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
layout,
|
||||
scale_a,
|
||||
scale_b,
|
||||
use_fast_accum,
|
||||
device_tma=True,
|
||||
)
|
||||
scaled_mm_device_tma_template.maybe_append_choice(
|
||||
choices,
|
||||
input_nodes=kernel_inputs.nodes(),
|
||||
input_nodes=triton_input_nodes,
|
||||
layout=layout,
|
||||
workspace_arg=get_tma_workspace_arg(
|
||||
num_tma_descriptors=2,
|
||||
|
|
@@ -1219,11 +1226,7 @@ def tuned_scaled_mm(
|
|||
**kwargs,
|
||||
)
|
||||
|
||||
# Get template params using the new unified function
|
||||
for kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs, layout, mm_template.name, "scaled_mm"
|
||||
):
|
||||
kwargs["USE_FAST_ACCUM"] = use_fast_accum
|
||||
for config in scaled_mm_configs(m, n, k):
|
||||
if V.graph.sizevars.guard_or_false(sympy.Le(k, 16)):
|
||||
# Triton crashes however uncommon for real workloads
|
||||
continue
|
||||
|
|
@@ -1233,10 +1236,13 @@ def tuned_scaled_mm(
|
|||
if using_b200() and V.graph.sizevars.guard_or_false(sympy.Lt(k, 32)):
|
||||
continue
|
||||
|
||||
kwargs = scaled_mm_options(
|
||||
config, m, n, k, layout, scale_a, scale_b, use_fast_accum
|
||||
)
|
||||
# possibly appends a TritonTemplateCaller to choices
|
||||
mm_template.maybe_append_choice(
|
||||
choices,
|
||||
input_nodes=kernel_inputs.nodes(),
|
||||
input_nodes=triton_input_nodes,
|
||||
layout=layout,
|
||||
**kwargs,
|
||||
suffix_args=suffix_args,
|
||||
|
|
@@ -1252,16 +1258,14 @@ def tuned_scaled_mm(
|
|||
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
|
||||
choices,
|
||||
layout,
|
||||
kernel_inputs.nodes(), # type: ignore[arg-type]
|
||||
input_nodes, # type: ignore[arg-type]
|
||||
use_fast_accum=use_fast_accum, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
if is_nonzero and use_ck_gemm_template(layout, m, n, k):
|
||||
CKGemmTemplate.add_ck_gemm_choices(choices, layout, kernel_inputs.nodes())
|
||||
CKGemmTemplate.add_ck_gemm_choices(choices, layout, input_nodes)
|
||||
|
||||
return autotune_select_algorithm(
|
||||
"scaled_mm", choices, kernel_inputs.nodes(), layout
|
||||
)
|
||||
return autotune_select_algorithm("scaled_mm", choices, input_nodes, layout)
|
||||
|
||||
|
||||
@functools.cache
@@ -3,13 +3,17 @@ import logging
|
|||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import sympy
|
||||
|
||||
import torch
|
||||
from torch._inductor.select_algorithm import realize_inputs, SymbolicGridFn
|
||||
from torch._inductor.utils import sympy_product
|
||||
from torch._inductor.virtualized import V
|
||||
|
||||
from .. import config as inductor_config
|
||||
from ..codegen.wrapper import PythonWrapperCodegen
|
||||
from ..ir import _IntLike, Layout, TensorBox
|
||||
from ..utils import get_num_sms, TMA_DESCRIPTOR_SIZE
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
|
@@ -45,6 +49,96 @@ def acc_type(dtype):
|
|||
return f"tl.{dtype}".replace("torch.", "")
|
||||
|
||||
|
||||
def mm_options(config, sym_m, sym_n, sym_k, layout):
|
||||
"""
|
||||
Common options to matmul triton templates.
|
||||
"""
|
||||
even_k_symbolic = (
|
||||
# it isn't worth guarding on this
|
||||
sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"]
|
||||
)
|
||||
allow_tf32 = torch.backends.cuda.matmul.allow_tf32 and (
|
||||
not inductor_config.force_same_precision
|
||||
or ((sym_m % 16) == 0 and (sym_n % 16) == 0 and (sym_k % 8) == 0)
|
||||
)
|
||||
options_dict = dict(
|
||||
EVEN_K=even_k_symbolic,
|
||||
ALLOW_TF32=allow_tf32,
|
||||
USE_FAST_ACCUM=False, # Option for _scaled_mm
|
||||
ACC_TYPE=acc_type(layout.dtype),
|
||||
num_stages=config.num_stages,
|
||||
num_warps=config.num_warps,
|
||||
**config.kwargs,
|
||||
)
|
||||
|
||||
# If GROUP_M not specified then default to 8
|
||||
if "GROUP_M" not in config.kwargs:
|
||||
group_m = config.kwargs.get("GROUP_M", 8)
|
||||
options_dict["GROUP_M"] = group_m
|
||||
|
||||
return options_dict
|
||||
|
||||
|
||||
def tma_options() -> dict[str, Any]:
|
||||
from torch.utils._triton import has_triton_stable_tma_api
|
||||
|
||||
return {"TMA_EXPERIMENTAL_API": not has_triton_stable_tma_api()}
|
||||
|
||||
|
||||
def persistent_mm_options(mat1, mat2):
|
||||
res = {
|
||||
"A_ROW_MAJOR": not mat1.layout.is_transposed(),
|
||||
"B_ROW_MAJOR": not mat2.layout.is_transposed(),
|
||||
"NUM_SMS": get_num_sms(),
|
||||
"TMA_SIZE": TMA_DESCRIPTOR_SIZE,
|
||||
}
|
||||
res.update(tma_options())
|
||||
return res
|
||||
|
||||
|
||||
def scaled_mm_options( # type: ignore[no-untyped-def]
|
||||
config, # triton.Config
|
||||
sym_m: sympy.core.numbers.Integer,
|
||||
sym_n: sympy.core.numbers.Integer,
|
||||
sym_k: sympy.core.numbers.Integer,
|
||||
layout: Layout,
|
||||
scale_a,
|
||||
scale_b,
|
||||
use_fast_accum: bool,
|
||||
device_tma: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
def are_compatible_scales(size_a, size_b) -> bool:
|
||||
# Same sized scales are compatible
|
||||
if len(size_a) == len(size_b):
|
||||
return True
|
||||
|
||||
# Both need to be scalars or len(1) tensors
|
||||
if len(size_a) <= 1 and len(size_b) <= 1:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
size_a, size_b = scale_a.get_size(), scale_b.get_size()
|
||||
assert are_compatible_scales(size_a, size_b), (
|
||||
"Expect scale_a and scale_b to be either both scalars (including single-element tensors) "
|
||||
f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}."
|
||||
)
|
||||
|
||||
mm_template_options = mm_options(config, sym_m, sym_n, sym_k, layout)
|
||||
|
||||
mm_template_options["ACC_TYPE"] = "tl.float32"
|
||||
mm_template_options["USE_FAST_ACCUM"] = use_fast_accum
|
||||
mm_template_options["SCALING_ROWWISE"] = len(size_a) == 2
|
||||
|
||||
if device_tma:
|
||||
mm_template_options["TMA_SIZE"] = TMA_DESCRIPTOR_SIZE
|
||||
mm_template_options["NUM_SMS"] = get_num_sms()
|
||||
|
||||
mm_template_options.update(tma_options())
|
||||
|
||||
return mm_template_options
|
||||
|
||||
|
||||
def mm_args(
|
||||
mat1,
|
||||
mat2,
|
||||
|
|
@@ -87,6 +181,20 @@ def mm_args(
|
|||
return [m, n, k, layout, mat1, mat2, *others]
|
||||
|
||||
|
||||
def mm_config_kwargs(device, exclude_condition, dtype_size=None):
|
||||
if device == "cpu":
|
||||
return {
|
||||
"scale": 0.5,
|
||||
"exclude": exclude_condition,
|
||||
}
|
||||
|
||||
if dtype_size and inductor_config.max_autotune_gemm_search_space == "EXHAUSTIVE":
|
||||
return {
|
||||
"dtype_size": dtype_size,
|
||||
}
|
||||
return {}
|
||||
|
||||
|
||||
def addmm_epilogue(dtype, alpha, beta):
|
||||
def epilogue(acc, bias):
|
||||
if alpha != 1:
@@ -1,10 +1,8 @@
|
|||
# mypy: allow-untyped-defs
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
from ..kernel_inputs import MMKernelInputs
|
||||
from .. import ir
|
||||
from ..lowering import lowerings
|
||||
from ..select_algorithm import (
|
||||
autotune_select_algorithm,
|
||||
|
|
@@ -13,11 +11,9 @@ from ..select_algorithm import (
|
|||
)
|
||||
from ..utils import use_aten_gemm_kernels, use_triton_template
|
||||
from ..virtualized import V
|
||||
from .mm_common import mm_args, mm_grid
|
||||
from .mm_common import mm_args, mm_grid, mm_options
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
aten_mm_plus_mm = ExternKernelChoice(
|
||||
|
|
@@ -123,9 +119,9 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
|
|||
"""
|
||||
Computes mm(mat1, mat2) + mm(mat3, mat4)
|
||||
"""
|
||||
# TODO(coconutruben): integrate into MMKernelInputs when all callsites use that
|
||||
m1, n1, k1, layout1, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
|
||||
m2, n2, _, layout2, mat3, mat4 = mm_args(mat3, mat4, layout=layout)
|
||||
device_type = ir.get_device_type(mat1)
|
||||
|
||||
# Optimization is optional, because we can always just not do the fusion
|
||||
if (
|
||||
|
|
@@ -144,34 +140,27 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
|
|||
lowerings[aten.mm](mat1, mat2), lowerings[aten.mm](mat3, mat4)
|
||||
)
|
||||
|
||||
# Create MMKernelInputs for MM Plus MM (matrices are at indices 0, 1 for first pair)
|
||||
# Note: This is a special case with 4 matrices, but we use the first pair for M, N, K extraction
|
||||
kernel_inputs = MMKernelInputs([mat1, mat2, mat3, mat4], mat1_idx=0, mat2_idx=1)
|
||||
|
||||
assert layout1 == layout2
|
||||
# options to tune from
|
||||
choices = (
|
||||
[aten_mm_plus_mm.bind(kernel_inputs.nodes(), layout1)]
|
||||
[aten_mm_plus_mm.bind((mat1, mat2, mat3, mat4), layout1)]
|
||||
if use_aten_gemm_kernels()
|
||||
else []
|
||||
)
|
||||
|
||||
mm_configs = V.choices.get_mm_plus_mm_configs(device_type)
|
||||
if use_triton_template(layout1):
|
||||
# Get template params using the new unified function
|
||||
for kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs, layout1, mm_plus_mm_template.name, "mm_plus_mm"
|
||||
):
|
||||
# Apply BLOCK_K constraint specific to mm_plus_mm
|
||||
for config in mm_configs():
|
||||
# see https://github.com/triton-lang/triton/issues/1298
|
||||
# BLOCK_K = K causes llvm error
|
||||
if V.graph.sizevars.statically_known_lt(kwargs.get("BLOCK_K", k1), k1):
|
||||
if V.graph.sizevars.statically_known_lt(config.kwargs["BLOCK_K"], k1):
|
||||
mm_plus_mm_template.maybe_append_choice(
|
||||
choices,
|
||||
input_nodes=kernel_inputs.nodes(),
|
||||
input_nodes=(mat1, mat2, mat3, mat4),
|
||||
layout=layout1,
|
||||
**kwargs,
|
||||
**mm_options(config, m1, n1, k1, layout1),
|
||||
)
|
||||
|
||||
return autotune_select_algorithm(
|
||||
"mm_plus_mm", choices, kernel_inputs.nodes(), layout1
|
||||
"mm_plus_mm", choices, [mat1, mat2, mat3, mat4], layout1
|
||||
)
@@ -1,237 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch._inductor.config
|
||||
from torch._inductor import ir
|
||||
from torch._inductor.virtualized import V
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sympy
|
||||
|
||||
|
||||
class KernelInputs:
|
||||
"""
|
||||
Class to store and provide access to input nodes for kernels.
|
||||
This class takes in a tuple of input nodes and provides methods to access
|
||||
information about these nodes, such as their device type and device.
|
||||
"""
|
||||
|
||||
def __init__(self, input_nodes: list[Any]):
|
||||
"""
|
||||
Initialize with a tuple of input nodes.
|
||||
|
||||
Args:
|
||||
input_nodes: A tuple of input nodes to store
|
||||
"""
|
||||
self._input_nodes = input_nodes
|
||||
self._device_name: Optional[str] = None
|
||||
assert len(input_nodes) > 0, "Expected at least one input node"
|
||||
|
||||
def nodes(self, reorder: Optional[Sequence[int]] = None) -> list[Any]:
|
||||
"""
|
||||
Return the stored input nodes, optionally reordered.
|
||||
|
||||
Args:
|
||||
reorder: Optional sequence of indices to reorder the nodes.
|
||||
For example, (2, 0, 1) would return nodes in that order.
|
||||
|
||||
Returns:
|
||||
The tuple of input nodes, optionally reordered
|
||||
"""
|
||||
if reorder is None:
|
||||
return self._input_nodes
|
||||
assert len(self._input_nodes) == len(reorder), (
|
||||
f"Reorder length mismatch: {len(self._input_nodes)} vs {len(reorder)}"
|
||||
)
|
||||
return [self._input_nodes[i] for i in reorder]
|
||||
|
||||
@property
|
||||
def device_type(self) -> Optional[str]:
|
||||
"""
|
||||
Get the device type of the first node.
|
||||
|
||||
Returns:
|
||||
The device type (e.g., 'cuda', 'cpu')
|
||||
"""
|
||||
|
||||
return ir.get_device_type(self._input_nodes[0])
|
||||
|
||||
def device(self) -> torch.device:
|
||||
"""
|
||||
Get the device of the first node.
|
||||
|
||||
Returns:
|
||||
The device of the first node
|
||||
"""
|
||||
return self._input_nodes[0].get_device()
|
||||
|
||||
def device_name(self) -> Optional[str]:
|
||||
"""
|
||||
Get the device name information.
|
||||
|
||||
Returns:
|
||||
A tuple of (gpu_name, vendor, model)
|
||||
"""
|
||||
if self._device_name is None:
|
||||
device = self.device()
|
||||
if self.device_type == "cuda":
|
||||
device_properties = torch.cuda.get_device_properties(device)
|
||||
self._device_name = device_properties.gcnArchName
|
||||
return self._device_name
|
||||
|
||||
def shapes_symbolic(self) -> tuple[tuple[Any, ...], ...]:
|
||||
"""
|
||||
Get the symbolic shapes of all input nodes.
|
||||
|
||||
Returns:
|
||||
A tuple of shape tuples for each input node
|
||||
"""
|
||||
return tuple(node.get_size() for node in self._input_nodes)
|
||||
|
||||
def shapes_hinted(self) -> tuple[tuple[int, ...], ...]:
|
||||
"""
|
||||
Get the size hints for shapes of all input nodes.
|
||||
|
||||
Returns:
|
||||
A tuple of shape tuples with integer hints for each input node
|
||||
"""
|
||||
return tuple(
|
||||
V.graph.sizevars.size_hints(
|
||||
node.get_size(),
|
||||
fallback=torch._inductor.config.unbacked_symint_fallback,
|
||||
)
|
||||
for node in self._input_nodes
|
||||
)
|
||||
|
||||
def strides_symbolic(self) -> tuple[tuple[sympy.Integer, ...], ...]:
|
||||
"""
|
||||
Get the symbolic strides of all input nodes.
|
||||
|
||||
Returns:
|
||||
A tuple of stride tuples for each input node
|
||||
"""
|
||||
return tuple(node.get_stride() for node in self._input_nodes)
|
||||
|
||||
def strides_hinted(self) -> tuple[tuple[int, ...], ...]:
|
||||
"""
|
||||
Get the size hints for strides of all input nodes.
|
||||
|
||||
Returns:
|
||||
A tuple of stride tuples with integer hints for each input node
|
||||
"""
|
||||
return tuple(
|
||||
V.graph.sizevars.size_hints(
|
||||
node.get_stride(),
|
||||
fallback=torch._inductor.config.unbacked_symint_fallback,
|
||||
)
|
||||
for node in self._input_nodes
|
||||
)
|
||||
|
||||
def dtypes(self) -> tuple[torch.dtype, ...]:
|
||||
"""
|
||||
Get the dtypes of all input nodes.
|
||||
|
||||
Returns:
|
||||
A tuple of dtypes for each input node
|
||||
"""
|
||||
return tuple(node.get_dtype() for node in self._input_nodes)
|
||||
|
||||
def dtype(self, idx: int = 0) -> torch.dtype:
|
||||
"""
|
||||
Get the dtype of a specific input node.
|
||||
|
||||
Args:
|
||||
idx: Index of the node to get the dtype from (default: 0)
|
||||
|
||||
Returns:
|
||||
The dtype of the specified input node
|
||||
"""
|
||||
return self._input_nodes[idx].get_dtype()
|
||||
|
||||
|
||||
class MMKernelInputs(KernelInputs):
|
||||
"""
|
||||
Specialized KernelInputs for matrix multiplication operations.
|
||||
Provides additional methods to access M, N, K dimensions.
|
||||
"""
|
||||
|
||||
def __init__(self, input_nodes: list[Any], mat1_idx: int = -2, mat2_idx: int = -1):
|
||||
"""
|
||||
Initialize with a tuple of input nodes.
|
||||
|
||||
By default, we assume the last 2 input nodes are mat1 and mat2, but
|
||||
the caller can adjust when necessary
|
||||
"""
|
||||
super().__init__(input_nodes)
|
||||
# for mm, we need at least 2 nodes, and we need to know which nodes
|
||||
# are the main matrixes e.g. addmm is (bias, mat1, mat2) whereas others
|
||||
# might be (mat1, mat2, scale), etc.
|
||||
assert len(self._input_nodes) >= 2, "Expected at least 2 input nodes"
|
||||
|
||||
# Adjust assertions to handle negative indices
|
||||
m1_idx, m2_idx = mat1_idx, mat2_idx
|
||||
if mat1_idx < 0:
|
||||
m1_idx += len(input_nodes)
|
||||
if mat2_idx < 0:
|
||||
m2_idx += len(input_nodes)
|
||||
|
||||
assert 0 <= m1_idx < len(input_nodes), f"Invalid mat1_idx: {mat1_idx}"
|
||||
assert 0 <= m1_idx < len(input_nodes), f"Invalid mat2_idx: {mat2_idx}"
|
||||
|
||||
self._mat1_idx = mat1_idx
|
||||
self._mat2_idx = mat2_idx
|
||||
|
||||
def mnk_symbolic(
|
||||
self,
|
||||
) -> tuple[sympy.Integer, sympy.Integer, sympy.Integer]:
|
||||
"""
|
||||
Get the symbolic M, N, K dimensions for matrix multiplication.
|
||||
Handles both 2D (MM) and 3D (BMM) tensors.
|
||||
|
||||
M is extracted from the second-to-last dimension of the first operand (mat1).
|
||||
N is extracted from the last dimension of the second operand (mat2).
|
||||
K is extracted from the last dimension of the first operand (mat1).
|
||||
|
||||
Returns:
|
||||
A tuple of (M, N, K) dimensions
|
||||
"""
|
||||
mat1 = self.nodes()[self._mat1_idx]
|
||||
mat2 = self.nodes()[self._mat2_idx]
|
||||
|
||||
m = mat1.get_size()[-2] # M from second-to-last dimension of mat1
|
||||
k = mat1.get_size()[-1] # K from last dimension of mat1
|
||||
n = mat2.get_size()[-1] # N from last dimension of mat2
|
||||
|
||||
# Ensure K dimensions match between operands
|
||||
k0 = mat2.get_size()[-2] # K from second-to-last dimension of mat2
|
||||
V.graph.sizevars.check_equals(k, k0)
|
||||
return (m, n, k)
|
||||
|
||||
def mnk_hinted(self) -> tuple[int, int, int]:
|
||||
"""
|
||||
Get the hinted M, N, K dimensions for matrix multiplication.
|
||||
Handles both 2D (MM) and 3D (BMM) tensors.
|
||||
|
||||
Uses shapes_hinted from the base class to get integer hints for dimensions.
|
||||
|
||||
Returns:
|
||||
A tuple of (M, N, K) dimensions as integers
|
||||
"""
|
||||
hinted_shapes = self.shapes_hinted()
|
||||
mat1_shape = hinted_shapes[self._mat1_idx]
|
||||
mat2_shape = hinted_shapes[self._mat2_idx]
|
||||
|
||||
m = mat1_shape[-2] # M from second-to-last dimension of mat1
|
||||
k = mat1_shape[-1] # K from last dimension of mat1
|
||||
n = mat2_shape[-1] # N from last dimension of mat2
|
||||
|
||||
# Ensure K dimensions match between operands
|
||||
k_check = mat2_shape[-2] # K from second-to-last dimension of mat2
|
||||
assert k == k_check, f"K dimensions don't match: {k} vs {k_check}"
|
||||
|
||||
return (m, n, k)
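The deleted file above defined the KernelInputs / MMKernelInputs helpers that the consolidated retrieval path relied on. A minimal usage sketch, assuming inp_expanded, mat1 and mat2 are Inductor IR nodes as in the addmm lowering earlier in this diff (illustrative only; the class is removed by this revert):

# By default the last two inputs are treated as mat1/mat2, so the addmm-style
# ordering (bias, mat1, mat2) needs no explicit indices.
kernel_inputs = MMKernelInputs([inp_expanded, mat1, mat2])
m, n, k = kernel_inputs.mnk_symbolic()       # symbolic M, N, K with a K-consistency check
device_type = kernel_inputs.device_type      # e.g. "cuda" or "cpu", taken from the first node
cutlass_nodes = kernel_inputs.nodes(reorder=[1, 2, 0])  # (mat1, mat2, bias) order for CUTLASS/CK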
@@ -7,16 +7,11 @@ from functools import partial
|
|||
from threading import Lock
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING
|
||||
|
||||
import sympy
|
||||
|
||||
import torch
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
from torch.utils._triton import has_triton_stable_tma_api
|
||||
|
||||
from . import config, config as inductor_config
|
||||
from .kernel_inputs import KernelInputs, MMKernelInputs
|
||||
from .template_registry import register_template_heuristic
|
||||
from .utils import get_backend_num_stages, get_num_sms, TMA_DESCRIPTOR_SIZE
|
||||
from . import config
|
||||
from .utils import get_backend_num_stages
|
||||
from .virtualized import V
|
||||
|
||||
|
||||
|
|
@@ -152,9 +147,6 @@ class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
|
|||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Whether the heuristic is used for int8. Use this when the heuristic is int8 exclusive
|
||||
# but prefer the preprocess_mm_configs argument when it's used for both
|
||||
self.has_int8_tensor: bool = False
|
||||
# List of dictionaries to store the kernel configs. Configs that evaluate to true
|
||||
# will be utilised on the target platform. The configs are as follows:
|
||||
# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
|
||||
|
|
@@ -475,7 +467,7 @@ class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
|
|||
configs: list[BaseConfig],
|
||||
scale: float,
|
||||
has_int8_tensor: bool,
|
||||
exclude: Callable[[sympy.Integer, sympy.Integer, sympy.Integer], bool],
|
||||
exclude: Callable[[int, int, int], bool],
|
||||
hint_override: Optional[int] = None,
|
||||
) -> list[BaseConfig]:
|
||||
"""
|
||||
|
|
@@ -484,7 +476,7 @@ class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
|
|||
from .runtime.runtime_utils import next_power_of_2
|
||||
|
||||
min_block_size = 16
|
||||
min_block_size_k = 32 if (has_int8_tensor or self.has_int8_tensor) else 16
|
||||
min_block_size_k = 32 if has_int8_tensor else 16
|
||||
|
||||
scaled_configs = []
|
||||
for hint_override in [None] + config.multi_kernel_hints:
|
||||
|
|
@@ -569,13 +561,6 @@ class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
|
|||
|
||||
return pruned_configs
|
||||
|
||||
def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]:
|
||||
"""
|
||||
Filter configs based on specific requirements.
|
||||
Subclasses can override this to implement custom filtering logic.
|
||||
"""
|
||||
return configs
|
||||
|
||||
def preprocess_mm_configs(
|
||||
self,
|
||||
m: int,
|
||||
|
|
@@ -583,17 +568,14 @@ class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
|
|||
k: int,
|
||||
configs: list[BaseConfig],
|
||||
has_int8_tensor: bool = False,
|
||||
scale: float = 1.0,
|
||||
exclude: Callable[
|
||||
[sympy.Integer, sympy.Integer, sympy.Integer], bool
|
||||
] = lambda m, n, k: False,
|
||||
scale: int = 1,
|
||||
exclude: Callable[[int, int, int], bool] = lambda m, n, k: False,
|
||||
dtype_size: int = 0,
|
||||
op_name: str = "mm", # For preprocessing overrides e.g. on CPU
|
||||
) -> Generator[TritonConfig, None, None]:
|
||||
configs = self._filter_configs(configs)
|
||||
scaled_configs = self._scale_mm_configs(
|
||||
m, n, k, configs, scale, has_int8_tensor, exclude
|
||||
)
|
||||
|
||||
if config.max_autotune_gemm_search_space == "EXHAUSTIVE":
|
||||
assert dtype_size > 0, "dtype_size must be provided for exhaustive search"
|
||||
scaled_configs = self._prune_exhaustive_configs(scaled_configs, dtype_size)
|
||||
|
|
@@ -612,10 +594,48 @@ class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
|
|||
def get_exhaustive_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
||||
return partial(self.preprocess_mm_configs, configs=self.exhaustive_configs)
|
||||
|
||||
def get_conv_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
||||
return partial(
|
||||
self.preprocess_mm_configs, configs=self.conv_configs, op_name="conv"
|
||||
def get_extra_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
||||
return partial(self.preprocess_mm_configs, configs=self.extra_mm_configs)
|
||||
|
||||
def get_int8_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
||||
return partial(self.preprocess_mm_configs, configs=self.int8_mm_configs)
|
||||
|
||||
def get_mixed_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
||||
mm_configs = (
|
||||
self.mm_configs + self.mixed_mm_configs
|
||||
if config.max_autotune_gemm_search_space == "EXHAUSTIVE"
|
||||
else self.mm_configs
|
||||
)
|
||||
return partial(self.preprocess_mm_configs, configs=mm_configs)
|
||||
|
||||
def get_persistent_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
||||
persistent_mm_configs = (
|
||||
self.exhaustive_configs
|
||||
if config.max_autotune_gemm_search_space == "EXHAUSTIVE"
|
||||
else self.persistent_mm_configs
|
||||
)
|
||||
|
||||
# num_warps=2 not safe for TMA
|
||||
persistent_mm_configs = [
|
||||
config for config in persistent_mm_configs if config.num_warps != 2
|
||||
]
|
||||
return partial(self.preprocess_mm_configs, configs=persistent_mm_configs)
|
||||
|
||||
def get_scaled_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
||||
return partial(self.preprocess_mm_configs, configs=self.scaled_mm_configs)
|
||||
|
||||
def get_scaled_persistent_mm_configs(
|
||||
self,
|
||||
) -> partial[Generator[TritonConfig, None, None]]:
|
||||
return partial(
|
||||
self.preprocess_mm_configs, configs=self.scaled_persistent_mm_configs
|
||||
)
|
||||
|
||||
def get_mm_plus_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
||||
return partial(self._finalize_mm_configs, configs=self.mm_plus_mm_configs)
|
||||
|
||||
def get_conv_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
||||
return partial(self.preprocess_mm_configs, configs=self.conv_configs)
|
||||
|
||||
# Flex attn helpers
|
||||
def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]:
|
||||
|
|
@@ -676,80 +696,7 @@ class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
|
|||
|
||||
|
||||
class CPUConfigHeuristic(BaseConfigHeuristic):
|
||||
"""
|
||||
CPU-specific config heuristic with CPU-specific optimizations.
|
||||
"""
|
||||
|
||||
def _get_cpu_exclude_function(
|
||||
self, method: str = "bmm"
|
||||
) -> Callable[[sympy.Integer, sympy.Integer, sympy.Integer], bool]:
|
||||
"""
|
||||
Get CPU-specific exclude function based on method type.
|
||||
Returns a function that can be used as exclude condition.
|
||||
Moved from mm_common._is_large_block_for_cpu and refactored to return a function.
|
||||
"""
|
||||
if method in ("conv"):
|
||||
|
||||
def exclude_conv(
|
||||
m: sympy.Integer, n: sympy.Integer, k: sympy.Integer
|
||||
) -> bool:
|
||||
# Thresholds are experimentally determined to reduce Triton CPU compile times
|
||||
if m > 256 or n > 256 or k > 256:
|
||||
return True
|
||||
return m * n * k > 2**17
|
||||
|
||||
return exclude_conv
|
||||
elif method in ("mm", "addmm", "int_mm"):
|
||||
|
||||
def exclude_mm(
|
||||
m: sympy.Integer, n: sympy.Integer, k: sympy.Integer
|
||||
) -> bool:
|
||||
return m * n > 2**13
|
||||
|
||||
return exclude_mm
|
||||
else: # Default to bmm implementation for unknown methods
|
||||
|
||||
def exclude_bmm(
|
||||
m: sympy.Integer, n: sympy.Integer, k: sympy.Integer
|
||||
) -> bool:
|
||||
if m > 128 or n > 128 or k > 128:
|
||||
return True
|
||||
return m * n > 2**12
|
||||
|
||||
return exclude_bmm
|
||||
|
||||
def preprocess_mm_configs(
|
||||
self,
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
configs: list[BaseConfig],
|
||||
has_int8_tensor: bool = False,
|
||||
scale: float = 1.0,
|
||||
exclude: Callable[
|
||||
[sympy.Integer, sympy.Integer, sympy.Integer], bool
|
||||
] = lambda m, n, k: False,
|
||||
dtype_size: int = 0,
|
||||
op_name: str = "mm", # For preprocessing overrides e.g. on CPU
|
||||
) -> Generator[TritonConfig, None, None]:
|
||||
"""
|
||||
CPU-specific preprocessing that applies CPU-specific scaling (0.5) and exclusion logic.
|
||||
"""
|
||||
# Get CPU-specific exclude function based on operation type
|
||||
cpu_exclude_fn = self._get_cpu_exclude_function(op_name)
|
||||
|
||||
# Apply CPU-specific scaling (0.5) and exclusion logic
|
||||
return super().preprocess_mm_configs(
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
configs=configs,
|
||||
has_int8_tensor=has_int8_tensor,
|
||||
scale=0.5,
|
||||
exclude=cpu_exclude_fn,
|
||||
dtype_size=dtype_size,
|
||||
op_name=op_name,
|
||||
)
|
||||
pass
|
||||
|
||||
|
||||
class CUDAConfigHeuristic(BaseConfigHeuristic):
|
||||
|
|
@@ -1055,13 +1002,14 @@ class ROCmConfigHeuristic(BaseConfigHeuristic):
            for wpeu in [0, int(8 // num_warps)]
        ]

    def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]:
        """
        ROCm specific filtering
        """
    def _filter_configs(
        self, configs: list[BaseConfig], new_num_stages: int
    ) -> list[BaseConfig]:
        # TODO: _filter_configs can be removed once backend specific configs are added
        # for all methods
        for c in configs:
            c.num_stages = self.default_num_stages
        return super()._filter_configs(configs)
        return configs

    def _finalize_mm_configs(
        self,
@@ -1128,6 +1076,57 @@ class ROCmConfigHeuristic(BaseConfigHeuristic):
                kwargs["GROUP_M"] = group_m
            yield self.triton_config(**kwargs)

    def get_extra_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        filtered_configs = self._filter_configs(
            self.extra_mm_configs, self.default_num_stages
        )
        return partial(self.preprocess_mm_configs, configs=filtered_configs)

    def get_int8_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        filtered_configs = self._filter_configs(
            self.int8_mm_configs, self.default_num_stages
        )
        return partial(self.preprocess_mm_configs, configs=filtered_configs)

    def get_mixed_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        mm_configs = (
            self.mm_configs + self.mixed_mm_configs
            if config.max_autotune_gemm_search_space == "EXHAUSTIVE"
            else self.mm_configs
        )
        filtered_configs = self._filter_configs(mm_configs, self.default_num_stages)
        return partial(self.preprocess_mm_configs, configs=filtered_configs)

    def get_persistent_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        filtered_configs = self._filter_configs(
            self.persistent_mm_configs, self.default_num_stages
        )
        return partial(self.preprocess_mm_configs, configs=filtered_configs)

    def get_scaled_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        filtered_configs = self._filter_configs(
            self.scaled_mm_configs, self.default_num_stages
        )
        return partial(self.preprocess_mm_configs, configs=filtered_configs)

    def get_scaled_persistent_mm_configs(
        self,
    ) -> partial[Generator[TritonConfig, None, None]]:
        filtered_configs = self._filter_configs(
            self.scaled_persistent_mm_configs, self.default_num_stages
        )
        return partial(self.preprocess_mm_configs, configs=filtered_configs)

    def get_mm_plus_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        filtered_configs = self._filter_configs(self.mm_plus_mm_configs, 1)
        return partial(self._finalize_mm_configs, configs=filtered_configs)

    def get_conv_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        filtered_configs = self._filter_configs(
            self.conv_configs, self.default_num_stages
        )
        return partial(self.preprocess_mm_configs, configs=filtered_configs)

    def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]:
        flex_attn_fwd_configs: list[FlexConfig] = []

@@ -1208,643 +1207,3 @@ class MTIAConfigHeuristic(BaseConfigHeuristic):
    """
    Placeholder child class for MTIA specific overrides.
    """


# Template-specific mixin classes


class TemplateConfigHeuristics:
    def get_template_configs(
        self,
        kernel_inputs: KernelInputs,
        layout: Any,
        op_name: str,
    ) -> Generator[dict[str, Any], None, None]:
        """
        Get template configs for the given inputs.
        This is the main entry point for template-specific logic.
        """
        # NOTE: not an abstract class, because that clashed below for the mixin
        # functionality. Can be adjusted, but not a high priority
        yield from {}


class MMTemplateConfigMixin(TemplateConfigHeuristics):
    """
    Mixin class that converts config lists to template kwargs.
    This handles the logic that was previously in choices.get_mm_configs.

    This mixin expects to be used with BaseConfigHeuristic or its subclasses.
    """

    # Type annotations to ensure the mixin works with BaseConfigHeuristic
    get_mm_configs: Callable[[], partial[Generator[TritonConfig, None, None]]]
    get_exhaustive_mm_configs: Callable[
        [], partial[Generator[TritonConfig, None, None]]
    ]
    _filter_configs: Callable[[list[BaseConfig]], list[BaseConfig]]

    def _get_config_generator(
        self,
    ) -> partial[Generator[TritonConfig, None, None]]:
        """
        Get the appropriate config generator based on search space.
        Can be overridden by subclasses for template-specific behavior.
        """
        # Handle exhaustive search case
        if config.max_autotune_gemm_search_space == "EXHAUSTIVE":
            return self.get_exhaustive_mm_configs()
        else:
            return self.get_mm_configs()

    def get_template_configs(
        self,
        kernel_inputs: KernelInputs,
        layout: Any,
        op_name: str,
    ) -> Generator[dict[str, Any], None, None]:
        """
        Convert config lists to template kwargs.
        This replaces the logic from choices.get_mm_configs and inlines mm_options.
        """
        assert isinstance(kernel_inputs, MMKernelInputs), (
            f"{self.__class__.__name__} requires MMKernelInputs"
        )
        input_nodes = kernel_inputs.nodes()
        if len(input_nodes) < 2:
            raise ValueError(f"Need at least 2 input tensors, got {len(input_nodes)}")

        # Extract M, N, K from kernel_inputs
        m, n, k = kernel_inputs.mnk_symbolic()

        # Extract dtype and device_type from kernel_inputs
        dtype = kernel_inputs.dtype()

        # Get the appropriate config generator
        configs = self._get_config_generator()

        # Generate and process configs
        for c in configs(m, n, k, dtype_size=dtype.itemsize, op_name=op_name):
            template_kwargs = self._convert_config_to_template_kwargs(
                c, m, n, k, layout
            )
            yield template_kwargs

    def _convert_config_to_template_kwargs(
        self,
        triton_config: TritonConfig,
        m: sympy.Integer,
        n: sympy.Integer,
        k: sympy.Integer,
        layout: Any,
    ) -> dict[str, Any]:
        """
        Convert triton config to template kwargs.
        Moved from mm_common.mm_options.
        """
        # Calculate EVEN_K symbolic
        even_k_symbolic = (
            # it isn't worth guarding on this
            sympy.gcd(k, triton_config.kwargs["BLOCK_K"])
            == triton_config.kwargs["BLOCK_K"]
        )

        # Calculate allow_tf32
        allow_tf32 = torch.backends.cuda.matmul.allow_tf32 and (
            not inductor_config.force_same_precision
            or ((m % 16) == 0 and (n % 16) == 0 and (k % 8) == 0)
        )

        # Build options dict
        options_dict = dict(
            EVEN_K=even_k_symbolic,
            ALLOW_TF32=allow_tf32,
            USE_FAST_ACCUM=False,  # Option for _scaled_mm
            ACC_TYPE=self._get_acc_type(layout.dtype),
            num_stages=triton_config.num_stages,
            num_warps=triton_config.num_warps,
            **triton_config.kwargs,
        )

        # If GROUP_M not specified then default to 8
        if "GROUP_M" not in triton_config.kwargs:
            group_m = triton_config.kwargs.get("GROUP_M", 8)
            options_dict["GROUP_M"] = group_m

        return options_dict

    def _get_acc_type(self, dtype: torch.dtype) -> str:
        """
        Get accumulator type for the given dtype.
        Moved from mm_common.acc_type.
        """
        if dtype in (torch.float16, torch.bfloat16):
            return "tl.float32"
        return f"tl.{dtype}".replace("torch.", "")

|
||||
class INT8MMTemplateConfigMixin(MMTemplateConfigMixin):
|
||||
"""
|
||||
Ensure that we feed in has_int8_tensor=True
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.has_int8_tensor = True
|
||||
|
||||
|
||||
# TMA-specific mixin for TMA templates
|
||||
class TMAConfigMixin(MMTemplateConfigMixin):
|
||||
"""
|
||||
TMA-specific mixin that uses persistent configs and adds TMA options.
|
||||
This inherits from MMTemplateConfigMixin and overrides config generation.
|
||||
"""
|
||||
|
||||
def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]:
|
||||
"""
|
||||
TMA specific filtering, as num_warps=2 not safe for TMA
|
||||
"""
|
||||
configs = [c for c in configs if c.num_warps != 2]
|
||||
return super()._filter_configs(configs)
|
||||
|
||||
def get_template_configs(
|
||||
self,
|
||||
kernel_inputs: KernelInputs,
|
||||
layout: Any,
|
||||
op_name: str,
|
||||
) -> Generator[dict[str, Any], None, None]:
|
||||
"""
|
||||
Generate TMA template configs by calling super and adding TMA-specific options.
|
||||
"""
|
||||
# Get base template configs from superclass
|
||||
for template_kwargs in super().get_template_configs(
|
||||
kernel_inputs, layout, op_name
|
||||
):
|
||||
# Add TMA-specific options (moved from mm_common.persistent_mm_options)
|
||||
input_nodes = kernel_inputs.nodes()
|
||||
self._add_tma_options(template_kwargs, input_nodes)
|
||||
yield template_kwargs
|
||||
|
||||
def _add_tma_options(
|
||||
self, template_kwargs: dict[str, Any], input_nodes: list[Any]
|
||||
) -> None:
|
||||
"""
|
||||
Add TMA-specific options to template kwargs.
|
||||
Moved from mm_common.persistent_mm_options and mm_common.tma_options.
|
||||
"""
|
||||
# For TMA templates, we need the actual matrix tensors
|
||||
mat1 = input_nodes[-2]
|
||||
mat2 = input_nodes[-1]
|
||||
|
||||
tma_opts = {
|
||||
"A_ROW_MAJOR": not mat1.layout.is_transposed(),
|
||||
"B_ROW_MAJOR": not mat2.layout.is_transposed(),
|
||||
"NUM_SMS": get_num_sms(),
|
||||
"TMA_SIZE": TMA_DESCRIPTOR_SIZE,
|
||||
"TMA_EXPERIMENTAL_API": not has_triton_stable_tma_api(),
|
||||
}
|
||||
template_kwargs.update(tma_opts)
|
||||
|
||||
|
||||
# Scaled MM-specific mixin for scaled MM templates (non-TMA)
|
||||
class ScaledMMConfigMixin(MMTemplateConfigMixin):
|
||||
"""
|
||||
Scaled MM-specific mixin that uses scaled configs and adds scaled MM options.
|
||||
This is for non-TMA scaled MM templates only.
|
||||
This inherits from MMTemplateConfigMixin and overrides config generation.
|
||||
"""
|
||||
|
||||
def get_template_configs(
|
||||
self,
|
||||
kernel_inputs: KernelInputs,
|
||||
layout: Any,
|
||||
op_name: str,
|
||||
) -> Generator[dict[str, Any], None, None]:
|
||||
"""
|
||||
Generate scaled MM template configs with scaled MM-specific options.
|
||||
Handles the remaining logic from mm_common including assertions and SCALING_ROWWISE.
|
||||
"""
|
||||
input_nodes = kernel_inputs.nodes()
|
||||
|
||||
# Initial assertion from mm_common.scaled_mm_options
|
||||
assert len(input_nodes) >= 4, (
|
||||
f"scaled_mm requires at least 4 inputs, got {len(input_nodes)}"
|
||||
)
|
||||
|
||||
# Extract scale tensors (typically scale_a and scale_b are input_nodes[2] and input_nodes[3])
|
||||
scale_a = input_nodes[2]
|
||||
scale_b = input_nodes[3]
|
||||
|
||||
# Scale compatibility assertion from mm_common.scaled_mm_options
|
||||
def are_compatible_scales(size_a: Any, size_b: Any) -> bool:
|
||||
# Same sized scales are compatible
|
||||
if len(size_a) == len(size_b):
|
||||
return True
|
||||
|
||||
# Both need to be scalars or len(1) tensors
|
||||
if len(size_a) <= 1 and len(size_b) <= 1:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
size_a, size_b = scale_a.get_size(), scale_b.get_size()
|
||||
assert are_compatible_scales(size_a, size_b), (
|
||||
"Expect scale_a and scale_b to be either both scalars (including single-element tensors) "
|
||||
f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}."
|
||||
)
|
||||
|
||||
# Get base template configs from superclass
|
||||
for template_kwargs in super().get_template_configs(
|
||||
kernel_inputs, layout, op_name
|
||||
):
|
||||
# Add scaled MM-specific options (moved from mm_common.scaled_mm_options)
|
||||
# Override accumulator type for scaled MM
|
||||
template_kwargs["ACC_TYPE"] = "tl.float32"
|
||||
# Add SCALING_ROWWISE attribute based on scale_a tensor shape
|
||||
template_kwargs["SCALING_ROWWISE"] = len(size_a) == 2
|
||||
|
||||
yield template_kwargs
|
||||
|
||||
|
||||
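For reference, the scale-compatibility rule asserted above, restated as a standalone check over plain shape tuples; the shapes below are invented examples, not taken from the commit.

def compatible_scales(size_a, size_b):
    # Same rank is compatible; otherwise both must be scalar-like (rank 0 or 1)
    if len(size_a) == len(size_b):
        return True
    return len(size_a) <= 1 and len(size_b) <= 1

print(compatible_scales((), (1,)))       # True: scalar vs single-element tensor
print(compatible_scales((128,), (128,))) # True: both 1-D (rowwise-style scales)
print(compatible_scales((128, 1), ()))   # False: 2-D scale vs scalar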
# Scaled TMA-specific mixin for scaled MM templates with TMA
class ScaledTMAConfigMixin(ScaledMMConfigMixin):
    """
    Scaled TMA-specific mixin that extends ScaledMMConfigMixin with TMA functionality.
    This is for scaled MM templates that use device TMA.
    This inherits from ScaledMMConfigMixin and adds TMA-specific options.
    """

    def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]:
        """
        TMA specific filtering, as num_warps=2 not safe for TMA
        """
        configs = [c for c in configs if c.num_warps != 2]
        return super()._filter_configs(configs)

    def get_template_configs(
        self,
        kernel_inputs: KernelInputs,
        layout: Any,
        op_name: str,
    ) -> Generator[dict[str, Any], None, None]:
        """
        Generate scaled TMA template configs with both scaled MM and TMA-specific options.
        """
        # Get base scaled MM template configs from superclass
        for template_kwargs in super().get_template_configs(
            kernel_inputs, layout, op_name
        ):
            # Add TMA-specific options for device TMA scaled MM
            template_kwargs["TMA_SIZE"] = TMA_DESCRIPTOR_SIZE
            template_kwargs["NUM_SMS"] = get_num_sms()
            template_kwargs["TMA_EXPERIMENTAL_API"] = not has_triton_stable_tma_api()

            yield template_kwargs


# Template-specific heuristic classes using multiple inheritance


# TODO(coconutruben): replace with template.name once templates are importable
@register_template_heuristic("mm", "cuda", register=torch.version.hip is None)
# TODO(coconutruben): replace with template.name once templates are importable
@register_template_heuristic("bmm", "cuda", register=torch.version.hip is None)
class CUDAMMTemplateConfigHeuristic(MMTemplateConfigMixin, CUDAConfigHeuristic):
    """Standard MM template heuristic for CUDA"""


# TODO(coconutruben): deprecate once autoheuristic is deprecated
# TODO(coconutruben): replace with template.name once templates are importable
@register_template_heuristic("mm-ah", "cuda", register=torch.version.hip is None)
class CUDAMMAHTemplateConfigHeuristic(MMTemplateConfigMixin, CUDAConfigHeuristic):
    """Standard MM template heuristic for CUDA using the extra mm configs only (for autoheuristic)"""

    def __init__(self) -> None:
        super().__init__()
        # Override mm_configs to use scaled_mm_configs
        self.mm_configs = self.extra_mm_configs
        self.exhaustive_configs = self.extra_mm_configs


# TODO(coconutruben): replace with template.name once templates are importable
@register_template_heuristic(
    "mm_persistent_tma", "cuda", register=torch.version.hip is None
)
class CUDAPersistentTMATemplateConfigHeuristic(TMAConfigMixin, CUDAConfigHeuristic):
    """Persistent TMA template heuristic for CUDA"""

    def __init__(self) -> None:
        super().__init__()
        # Override mm_configs to use persistent_mm_configs
        self.mm_configs = self.persistent_mm_configs


# TODO(coconutruben): replace with template.name once templates are importable
@register_template_heuristic(
    "mm", "cuda", register=torch.version.hip is None, op_name="scaled_mm"
)
class CUDAScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, CUDAConfigHeuristic):
    """Scaled MM template heuristic for CUDA"""

    def __init__(self) -> None:
        super().__init__()
        # Override mm_configs to use scaled_mm_configs
        self.mm_configs = self.scaled_mm_configs
        # NOTE: overriding exhaustive configs here to be the same as mm_configs
        # as we haven't validated exhaustive support here yet
        # TODO(coconutruben): remove this once we have validated exhaustive support
        # for scaled_mm
        self.exhaustive_configs = self.scaled_mm_configs


# TODO(coconutruben): replace with template.name once templates are importable
@register_template_heuristic(
    "scaled_mm_device_tma", "cuda", register=torch.version.hip is None
)
class CUDAScaledTMATemplateConfigHeuristic(ScaledTMAConfigMixin, CUDAConfigHeuristic):
    """Scaled TMA template heuristic for CUDA"""

    def __init__(self) -> None:
        super().__init__()
        # Override mm_configs to use scaled_persistent_mm_configs for TMA
        self.mm_configs = self.scaled_persistent_mm_configs
        # NOTE: overriding exhaustive configs here to be the same as mm_configs
        # as we haven't validated exhaustive support here yet
        # TODO(coconutruben): remove this once we have validated exhaustive support
        # for scaled_mm
        self.exhaustive_configs = self.scaled_persistent_mm_configs


# TODO(coconutruben): replace with template.name once templates are importable
@register_template_heuristic("mm_plus_mm", "cuda", register=torch.version.hip is None)
class CUDAMMPlusMMTemplateConfigHeuristic(MMTemplateConfigMixin, CUDAConfigHeuristic):
    """MM Plus MM template heuristic for CUDA"""

    def __init__(self) -> None:
        super().__init__()
        # Override mm_configs to use mm_plus_mm_configs
        self.mm_configs = self.mm_plus_mm_configs
        # NOTE: overriding exhaustive configs here to be the same as mm_configs
        # as we haven't validated exhaustive support here yet
        # TODO(coconutruben): remove this once we have validated exhaustive support
        # for scaled_mm
        self.exhaustive_configs = self.mm_plus_mm_configs


# TODO(coconutruben): replace with template.name once templates are importable
@register_template_heuristic(
    "mm", "cuda", register=torch.version.hip is None, op_name="int_mm"
)
class CUDAInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, CUDAConfigHeuristic):
    """Int8 MM template heuristic for CUDA"""

    def __init__(self) -> None:
        super().__init__()
        # Override mm_configs to use int8_mm_configs
        self.mm_configs = self.int8_mm_configs
        # NOTE: overriding exhaustive configs here to be the same as mm_configs
        # as we haven't validated exhaustive support here yet
        # TODO(coconutruben): remove this once we have validated exhaustive support
        # for scaled_mm
        self.exhaustive_configs = self.int8_mm_configs


# ROCm template-specific classes


# TODO(coconutruben): replace with template.name once templates are importable
@register_template_heuristic("mm", "cuda", register=torch.version.hip is not None)
# TODO(coconutruben): replace with template.name once templates are importable
@register_template_heuristic("bmm", "cuda", register=torch.version.hip is not None)
class ROCmMMTemplateConfigHeuristic(MMTemplateConfigMixin, ROCmConfigHeuristic):
    """Standard MM template heuristic for ROCm"""


# TODO(coconutruben): deprecate once autoheuristic is deprecated
# TODO(coconutruben): replace with template.name once templates are importable
@register_template_heuristic("mm-ah", "cuda", register=torch.version.hip is not None)
class ROCmMMAHTemplateConfigHeuristic(MMTemplateConfigMixin, ROCmConfigHeuristic):
    """Standard MM template heuristic for ROCm using the extra mm configs only (for autoheuristic)"""

    def __init__(self) -> None:
        super().__init__()
        # Override mm_configs to use scaled_mm_configs
        self.mm_configs = self.extra_mm_configs
        self.exhaustive_configs = self.extra_mm_configs


# TODO(coconutruben): replace with template.name once templates are importable
@register_template_heuristic(
    "mm", "cuda", register=torch.version.hip is not None, op_name="scaled_mm"
)
class ROCmScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, ROCmConfigHeuristic):
    """Scaled MM template heuristic for ROCm (non-TMA)"""

    def __init__(self) -> None:
        super().__init__()
        # Override mm_configs to use scaled_mm_configs
        self.mm_configs = self.scaled_mm_configs
        # NOTE: overriding exhaustive configs here to be the same as mm_configs
        # as we haven't validated exhaustive support here yet
        # TODO(coconutruben): remove this once we have validated exhaustive support
        # for scaled_mm
        self.exhaustive_configs = self.scaled_mm_configs


# TODO(coconutruben): replace with template.name once templates are importable
@register_template_heuristic(
    "mm", "cuda", register=torch.version.hip is not None, op_name="int_mm"
)
class ROCmInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, ROCmConfigHeuristic):
    """Int8 MM template heuristic for ROCm"""

    def __init__(self) -> None:
        super().__init__()
        # Override mm_configs to use int8_mm_configs
        self.mm_configs = self.int8_mm_configs
        # NOTE: overriding exhaustive configs here to be the same as mm_configs
        # as we haven't validated exhaustive support here yet
        # TODO(coconutruben): remove this once we have validated exhaustive support
        # for scaled_mm
        self.exhaustive_configs = self.int8_mm_configs


# TODO(coconutruben): replace with template.name once templates are importable
@register_template_heuristic(
    "mm_plus_mm", "cuda", register=torch.version.hip is not None
)
class ROCmMMPlusMMTemplateConfigHeuristic(MMTemplateConfigMixin, ROCmConfigHeuristic):
    """MM Plus MM template heuristic for ROCm"""

    def __init__(self) -> None:
        super().__init__()
        # Override mm_configs to use mm_plus_mm_configs
        self.mm_configs = self.mm_plus_mm_configs
        # NOTE: overriding exhaustive configs here to be the same as mm_configs
        # as we haven't validated exhaustive support here yet
        # TODO(coconutruben): remove this once we have validated exhaustive support
        # for scaled_mm
        self.exhaustive_configs = self.mm_plus_mm_configs


# CPU template-specific classes


@register_template_heuristic("mm", "cpu")
@register_template_heuristic("bmm", "cpu")
class CPUMMTemplateConfigHeuristic(MMTemplateConfigMixin, CPUConfigHeuristic):
    """Standard MM template heuristic for CPU"""


@register_template_heuristic("mm", "cpu", op_name="scaled_mm")
class CPUScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, CPUConfigHeuristic):
    """Scaled MM template heuristic for CPU (non-TMA)"""

    def __init__(self) -> None:
        super().__init__()
        # Override mm_configs to use scaled_mm_configs
        self.mm_configs = self.scaled_mm_configs
        # NOTE: overriding exhaustive configs here to be the same as mm_configs
        # as we haven't validated exhaustive support here yet
        # TODO(coconutruben): remove this once we have validated exhaustive support
        # for scaled_mm
        self.exhaustive_configs = self.scaled_mm_configs


@register_template_heuristic("mm", "cpu", op_name="int_mm")
class CPUInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, CPUConfigHeuristic):
    """Int8 MM template heuristic for CPU"""

    def __init__(self) -> None:
        super().__init__()
        # Override mm_configs to use int8_mm_configs
        self.mm_configs = self.int8_mm_configs
        # NOTE: overriding exhaustive configs here to be the same as mm_configs
        # as we haven't validated exhaustive support here yet
        # TODO(coconutruben): remove this once we have validated exhaustive support
        # for scaled_mm
        self.exhaustive_configs = self.int8_mm_configs


@register_template_heuristic("mm_plus_mm", "cpu")
class CPUMMPlusMMTemplateConfigHeuristic(MMTemplateConfigMixin, CPUConfigHeuristic):
    """MM Plus MM template heuristic for CPU"""

    def __init__(self) -> None:
        super().__init__()
        # Override mm_configs to use mm_plus_mm_configs
        self.mm_configs = self.mm_plus_mm_configs
        # NOTE: overriding exhaustive configs here to be the same as mm_configs
        # as we haven't validated exhaustive support here yet
        # TODO(coconutruben): remove this once we have validated exhaustive support
        # for scaled_mm
        self.exhaustive_configs = self.mm_plus_mm_configs


# XPU template-specific classes


@register_template_heuristic("mm", "xpu")
@register_template_heuristic("bmm", "xpu")
class XPUMMTemplateConfigHeuristic(MMTemplateConfigMixin, XPUConfigHeuristic):
    """Standard MM template heuristic for XPU"""


@register_template_heuristic("mm", "xpu", op_name="scaled_mm")
class XPUScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, XPUConfigHeuristic):
    """Scaled MM template heuristic for XPU (non-TMA)"""

    def __init__(self) -> None:
        super().__init__()
        # Override mm_configs to use scaled_mm_configs
        self.mm_configs = self.scaled_mm_configs
        # NOTE: overriding exhaustive configs here to be the same as mm_configs
        # as we haven't validated exhaustive support here yet
        # TODO(coconutruben): remove this once we have validated exhaustive support
        # for scaled_mm
        self.exhaustive_configs = self.scaled_mm_configs


@register_template_heuristic("mm", "xpu", op_name="int_mm")
class XPUInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, XPUConfigHeuristic):
    """Int8 MM template heuristic for XPU"""

    def __init__(self) -> None:
        super().__init__()
        # Override mm_configs to use int8_mm_configs
        self.mm_configs = self.int8_mm_configs
        # NOTE: overriding exhaustive configs here to be the same as mm_configs
        # as we haven't validated exhaustive support here yet
        # TODO(coconutruben): remove this once we have validated exhaustive support
        # for scaled_mm
        self.exhaustive_configs = self.int8_mm_configs


@register_template_heuristic("mm_plus_mm", "xpu")
class XPUMMPlusMMTemplateConfigHeuristic(MMTemplateConfigMixin, XPUConfigHeuristic):
    """MM Plus MM template heuristic for XPU"""

    def __init__(self) -> None:
        super().__init__()
        # Override mm_configs to use mm_plus_mm_configs
        self.mm_configs = self.mm_plus_mm_configs
        # NOTE: overriding exhaustive configs here to be the same as mm_configs
        # as we haven't validated exhaustive support here yet
        # TODO(coconutruben): remove this once we have validated exhaustive support
        # for scaled_mm
        self.exhaustive_configs = self.mm_plus_mm_configs


# MTIA template-specific classes


@register_template_heuristic("mm", "mtia")
@register_template_heuristic("bmm", "mtia")
class MTIAMMTemplateConfigHeuristic(MMTemplateConfigMixin, MTIAConfigHeuristic):
    """Standard MM template heuristic for MTIA"""


@register_template_heuristic("mm", "mtia", op_name="scaled_mm")
class MTIAScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, MTIAConfigHeuristic):
    """Scaled MM template heuristic for MTIA (non-TMA)"""

    def __init__(self) -> None:
        super().__init__()
        # Override mm_configs to use scaled_mm_configs
        self.mm_configs = self.scaled_mm_configs
        # NOTE: overriding exhaustive configs here to be the same as mm_configs
        # as we haven't validated exhaustive support here yet
        # TODO(coconutruben): remove this once we have validated exhaustive support
        # for scaled_mm
        self.exhaustive_configs = self.scaled_mm_configs


@register_template_heuristic("mm", "mtia", op_name="int_mm")
class MTIAInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, MTIAConfigHeuristic):
    """Int8 MM template heuristic for MTIA"""

    def __init__(self) -> None:
        super().__init__()
        # Override mm_configs to use int8_mm_configs
        self.mm_configs = self.int8_mm_configs
        # NOTE: overriding exhaustive configs here to be the same as mm_configs
        # as we haven't validated exhaustive support here yet
        # TODO(coconutruben): remove this once we have validated exhaustive support
        # for scaled_mm
        self.exhaustive_configs = self.int8_mm_configs


@register_template_heuristic("mm_plus_mm", "mtia")
class MTIAMMPlusMMTemplateConfigHeuristic(MMTemplateConfigMixin, MTIAConfigHeuristic):
    """MM Plus MM template heuristic for MTIA"""

    def __init__(self) -> None:
        super().__init__()
        # Override mm_configs to use mm_plus_mm_configs
        self.mm_configs = self.mm_plus_mm_configs
        # NOTE: overriding exhaustive configs here to be the same as mm_configs
        # as we haven't validated exhaustive support here yet
        # TODO(coconutruben): remove this once we have validated exhaustive support
        # for scaled_mm
        self.exhaustive_configs = self.mm_plus_mm_configs

@@ -1,98 +0,0 @@
"""
Template heuristic registry system for PyTorch Inductor.

This module provides a centralized registration system for template heuristics,
allowing automatic registration based on device type and conditional registration
for CUDA vs ROCm based on torch.version.hip.
"""

from __future__ import annotations

import logging
from functools import cache
from typing import Any, Optional, TYPE_CHECKING


if TYPE_CHECKING:
    from .template_heuristics import TemplateConfigHeuristics

# Module-wide registry for template heuristics
_TEMPLATE_HEURISTIC_REGISTRY: dict[tuple[str, ...], type[TemplateConfigHeuristics]] = {}

log = logging.getLogger(__name__)


def register_template_heuristic(
    template_name: str,
    device_type: str,
    register: bool = True,
    op_name: Optional[str] = None,
) -> Any:
    """
    Decorator to register template heuristic classes.

    Args:
        template_name: Name of the template (e.g., "mm", "bmm", "scaled_mm")
        device_type: Device type ("cuda", "cpu", "xpu")
        register: Whether to register this heuristic. Caller should pass the condition directly.
        op_name: Name of the operator (e.g., "mm", "bmm", "scaled_mm"). This is optional
            and is only used when a template uses different heuristics for different ops

    Returns:
        Decorator function that registers the class if conditions are met.

    Example:
        @register_template_heuristic("mm", "cuda", register=torch.version.hip is None)
        class CUDAMMTemplateConfigHeuristic(MMTemplateConfigMixin, CUDAConfigHeuristic):
            pass
    """

    def decorator(
        cls: type[TemplateConfigHeuristics],
    ) -> type[TemplateConfigHeuristics]:
        if register:
            key: tuple[str, ...] = (device_type, template_name)
            if op_name is not None:
                key = (device_type, template_name, op_name)
            _TEMPLATE_HEURISTIC_REGISTRY[key] = cls
            log.info(
                f"Registered template heuristic: {cls.__name__} for '{template_name=}', '{device_type=}', '{op_name=}'"  # noqa: G004
            )
        return cls

    return decorator


@cache
def get_template_heuristic(
    template_name: str, device_type: str, op_name: str
) -> TemplateConfigHeuristics:
    """
    Retrieve a template heuristic instance for the given template and device type.

    Args:
        template_name: Name of the template (e.g., "mm", "bmm", "scaled_mm")
        device_type: Device type ("cuda", "cpu", "xpu")

    Returns:
        Template heuristic instance.

    Raises:
        ValueError: If no heuristic is found for the given combination.
    """
    # First check the more specific key
    keys = [(device_type, template_name, op_name), (device_type, template_name)]

    # Look up in registry
    heuristic_class = None
    for key in keys:
        if key in _TEMPLATE_HEURISTIC_REGISTRY:
            heuristic_class = _TEMPLATE_HEURISTIC_REGISTRY[key]
            break
    if heuristic_class is None:
        raise ValueError(
            f"No template heuristic found for '{template_name=}', "
            f"'{device_type=}', '{op_name=}'. "
            f"Available combinations: {list(_TEMPLATE_HEURISTIC_REGISTRY.keys())}"
        )
    return heuristic_class()
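For orientation, a standalone sketch of the lookup precedence implemented by get_template_heuristic above, mirrored with a plain dict: the (device, template, op) key is tried before (device, template). The registry contents shown are illustrative, not part of the commit.

registry = {
    ("cuda", "mm"): "CUDAMMTemplateConfigHeuristic",
    ("cuda", "mm", "scaled_mm"): "CUDAScaledMMTemplateConfigHeuristic",
}

def lookup(template_name, device_type, op_name):
    # The more specific (device, template, op) key wins over (device, template)
    for key in [(device_type, template_name, op_name), (device_type, template_name)]:
        if key in registry:
            return registry[key]
    raise ValueError(f"No template heuristic found for {template_name=}, {device_type=}, {op_name=}")

print(lookup("mm", "cuda", "scaled_mm"))  # op-specific entry wins
print(lookup("mm", "cuda", "mm"))         # falls back to the (device, template) entry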