Reland "Introduce new template heuristic for triton autotune configs" (#147452)

This change was reverted in https://github.com/pytorch/pytorch/pull/147388 after it regressed an internal workload.

I have removed the additional ir.get_device_type calls in mm_scaled and unpack_mixed_mm.py, which could have been contributing to the additional compile time.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147452
Approved by: https://github.com/jansel
Jack Taylor 2025-03-26 15:47:06 +00:00 committed by PyTorch MergeBot
parent 7336b76bcc
commit 32299e5f9a
9 changed files with 739 additions and 640 deletions
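
Before the per-file diffs, a hedged sketch of the call pattern this change introduces: lowerings now look up per-device Triton config heuristics through V.choices instead of importing module-level config lists from mm_common. Only ir.get_device_type, V.choices.get_base_mm_configs, mm_config_kwargs and the _is_large_block_for_cpu rule come from the diff itself; the helper enumerate_mm_configs and its arguments are illustrative.

# Illustrative sketch (not part of the commit): how a lowering obtains Triton
# mm configs after this change. Requires an active Inductor graph, since the
# heuristics query V.graph.sizevars for size hints.
from torch._inductor import ir
from torch._inductor.virtualized import V
from torch._inductor.kernel.mm_common import mm_config_kwargs


def _is_large_block_for_cpu(m, n, k):
    # CPU-only exclusion rule, mirroring kernel/mm.py in this diff
    return m * n > 2**13


def enumerate_mm_configs(mat1, m, n, k):
    device_type = ir.get_device_type(mat1)
    # Device dispatch (CUDA / ROCm / XPU / CPU) now lives in InductorChoices
    mm_configs = V.choices.get_base_mm_configs(device_type)
    # mm_config_kwargs only adds scale/exclude kwargs on CPU
    return list(
        mm_configs(m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu))
    )
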


@@ -1,20 +1,33 @@
from __future__ import annotations
import typing
from typing import Any, Optional, TYPE_CHECKING
import sympy
import torch
from . import config
from .codecache import write_text
from .metrics import get_metric_table, is_metric_table_enabled
from .runtime.hints import DeviceProperties, ReductionHint
from .scheduler import BaseSchedulerNode, Scheduler, WhyNoFuse
from .template_heuristics import (
BaseConfigHeuristic,
CPUConfigHeuristic,
CUDAConfigHeuristic,
ROCmConfigHeuristic,
XPUConfigHeuristic,
)
from .virtualized import V
if TYPE_CHECKING:
from collections.abc import Generator
from functools import partial
from triton import Config as TritonConfig
from torch.utils._ordered_set import OrderedSet
from .codegen.simd_kernel_features import SIMDKernelFeatures
@@ -40,6 +53,80 @@ class InductorChoices:
torch._inductor.virtualized.V.set_choices_handler(MyHeuristics())
"""
def get_config_heuristics(
self, device_type: Optional[str] = "cuda"
) -> BaseConfigHeuristic:
if device_type == "cuda":
if torch.version.hip is None:
return CUDAConfigHeuristic()
else:
return ROCmConfigHeuristic()
elif device_type == "xpu":
return XPUConfigHeuristic()
elif device_type == "cpu":
return CPUConfigHeuristic()
else:
return BaseConfigHeuristic()
# GEMM configs
def get_base_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
if config.max_autotune_gemm_search_space != "EXHAUSTIVE":
return mm_heuristics.get_mm_configs()
else:
return mm_heuristics.get_exhaustive_mm_configs()
def get_extra_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_extra_mm_configs()
def get_int8_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_int8_mm_configs()
def get_mixed_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_mixed_mm_configs()
def get_persistent_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_persistent_mm_configs()
def get_scaled_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_scaled_mm_configs()
def get_scaled_persistent_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_scaled_persistent_mm_configs()
def get_mm_plus_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_mm_plus_mm_configs()
# Conv configs
def get_conv_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
conv_heuristics = self.get_config_heuristics(device_type)
return conv_heuristics.get_conv_configs()
def triton_kernel_kwargs(
self,
kernel_cls: type[TritonKernel],
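
The accessor methods added above route every GEMM and conv template through get_config_heuristics. As the class docstring notes, the handler can be swapped via V.set_choices_handler; below is a hedged sketch of such an override, using only classes visible in this diff (MyROCmHeuristic and the trimmed config list are illustrative, not part of the commit).

# Illustrative sketch: overriding the per-device heuristic selection.
import torch
from torch._inductor.choices import InductorChoices
from torch._inductor.template_heuristics import ROCmConfigHeuristic
from torch._inductor.virtualized import V


class MyROCmHeuristic(ROCmConfigHeuristic):
    def __init__(self) -> None:
        super().__init__()
        # Trim the autotuning space to a handful of configs (illustrative).
        self.mm_configs = self.mm_configs[:4]


class MyChoices(InductorChoices):
    def get_config_heuristics(self, device_type="cuda"):
        if device_type == "cuda" and torch.version.hip is not None:
            return MyROCmHeuristic()
        return super().get_config_heuristics(device_type)


V.set_choices_handler(MyChoices())
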


@@ -24,7 +24,7 @@ from .mm_common import (
_is_static_problem,
addmm_epilogue,
mm_args,
mm_config_kwargs,
mm_options,
should_fallback_to_aten,
)
@@ -46,12 +46,6 @@ def _is_large_block_for_cpu(m, n, k):
return m * n > 2**12
def bmm_configs(m, n, k, *, device_type):
if device_type == "cpu":
return mm_configs(m, n, k, scale=0.5, exclude=_is_large_block_for_cpu)
return mm_configs(m, n, k)
bmm_template = TritonTemplate(
name="bmm",
grid=bmm_grid,
@@ -184,8 +178,14 @@ def tuned_bmm(mat1, mat2, *, layout=None):
# options to tune from
choices = [aten_bmm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else []
device_type = ir.get_device_type(mat1)
bmm_configs = V.choices.get_base_mm_configs(device_type)
if use_triton_template(layout):
for config in bmm_configs(
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
):
bmm_template.maybe_append_choice(
choices,
input_nodes=(mat1, mat2),
@@ -239,8 +239,14 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
if use_aten_gemm_kernels()
else []
)
device_type = ir.get_device_type(mat1)
bmm_configs = V.choices.get_base_mm_configs(device_type)
if use_triton_template(layout):
for config in bmm_configs(
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
):
bmm_template.maybe_append_choice(
choices,
input_nodes=(inp, mat1, mat2),


@@ -2,7 +2,7 @@
from __future__ import annotations
import logging
from typing import Optional, TYPE_CHECKING, TypedDict
import torch
from torch._inductor.codegen.rocm.ck_conv_template import CKGroupedConvFwdTemplate
@@ -29,7 +29,7 @@ from ..utils import (
use_triton_template,
)
from ..virtualized import V
from .mm_common import mm_config_kwargs
if TYPE_CHECKING:
@@ -61,31 +61,6 @@ def conv3d_grid(n, c, d, h, w, meta, *, cdiv):
)
# List of dictionaries to store the kernel configs. Configs that evaluate to true
# will be utilised on the target platform
kernel_configs = [
# "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps"
{"config": (64, 256, 16, 2, 4), "cond": True},
{"config": (256, 64, 16, 2, 4), "cond": True},
{"config": (1024, 16, 16, 1, 8), "cond": True},
{"config": (128, 128, 32, 2, 8), "cond": True},
{"config": (64, 64, 32, 2, 4), "cond": True},
{"config": (64, 256, 32, 2, 8), "cond": True},
{"config": (256, 64, 32, 2, 8), "cond": True},
]
# Create filtered list of configs based on conv
platform_configs = tuple(
cast(tuple[int, int, int, int, int], config["config"])
for config in kernel_configs
if config["cond"]
)
# On ROCm convert num_stages to 1 as pipelining provides no benefit
if torch.version.hip and torch.cuda.is_available():
platform_configs = build_rocm_gemm_configs(platform_configs)
def _is_large_block_for_cpu(m, n, k):
# Thresholds are experimentally determined to reduce Triton CPU compile times
if m > 256 or n > 256 or k > 256:
@@ -93,19 +68,6 @@ def _is_large_block_for_cpu(m, n, k):
return m * n * k > 2**17
def conv_configs(m, n, k, *, device_type, **kwargs):
if device_type == "cpu":
return filtered_configs(
m,
n,
k,
configs=platform_configs,
scale=0.5,
exclude=_is_large_block_for_cpu,
)
return filtered_configs(m, n, k, configs=platform_configs)
LOOP_BODY_2D = """
idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H
idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W
@@ -497,6 +459,8 @@ def convolution(
"groups": groups,
}
device_type = ir.get_device_type(x)
if len(x.get_size()) == len(weight.get_size()) - 1:
# add batch dimension to simplify rest of function
return L[aten.squeeze](
@@ -511,11 +475,7 @@ def convolution(
# Always convert conv1D to 2D for Intel GPU.
# Only conv2D can be converted to channel last layout,
# which have much better performance.
if len(x.get_size()) == 3 and len(kernel_shape) == 1 and device_type == "xpu":
kwargs.update(
{
"stride": (1,) + stride,
@@ -564,7 +524,7 @@ def convolution(
):
return convert_1x1_conv_to_mm(x, weight, bias)
if bias is not None and device_type != "cpu":
# peel off the bias, cudnn is slower with it
result = convolution(x, weight, None, **kwargs)
return L[aten.add](
@@ -639,11 +599,13 @@ def convolution(
):
choices.append(aten_conv1x1_via_mm.bind(args, layout))
conv_configs = V.choices.get_conv_configs(device_type)
for cfg in conv_configs(
sympy_product([x.get_size()[0], *x.get_size()[2:]]),
out_chan,
in_chan,
**mm_config_kwargs(device_type, _is_large_block_for_cpu),
):
if ndim == 2:
conv2d_template.maybe_append_choice(


@@ -28,7 +28,6 @@ from ..select_algorithm import (
TritonTemplate,
)
from ..utils import (
get_gpu_shared_memory,
get_tma_workspace_arg,
use_aten_gemm_kernels,
use_ck_gemm_template,
@@ -41,17 +40,13 @@ from ..utils import (
from .mm_common import (
_is_static_problem,
addmm_epilogue,
extra_mm_configs,
int8_mm_configs,
mm_args,
mm_config_kwargs,
mm_grid,
mm_options,
persistent_mm_configs,
persistent_mm_grid,
persistent_mm_options,
should_fallback_to_aten,
triton_config,
)
@@ -341,15 +336,6 @@ def _is_large_block_for_cpu(m, n, k):
return m * n > 2**13
def mm_config_kwargs(device):
if device == "cpu":
return {
"scale": 0.5,
"exclude": _is_large_block_for_cpu,
}
return {}
def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1):
"""
Giving torch.addmm a 1D tensor calls a different (faster) cublasLt
@@ -367,6 +353,7 @@ aten_bias_addmm = ExternKernelChoice(bias_addmm, None)
@register_lowering(aten.mm, type_promotion_kind=None)
def tuned_mm(mat1, mat2, *, layout=None):
m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
device_type = ir.get_device_type(mat1)
name = "mm" name = "mm"
# below is for getting an overview logging info of inductor mms # below is for getting an overview logging info of inductor mms
@ -392,8 +379,15 @@ def tuned_mm(mat1, mat2, *, layout=None):
[aten_mm.bind((mat1, mat2), aten_layout)] if use_aten_gemm_kernels() else [] [aten_mm.bind((mat1, mat2), aten_layout)] if use_aten_gemm_kernels() else []
) )
static_shape, is_nonzero = _is_static_problem(layout) static_shape, is_nonzero = _is_static_problem(layout)
mm_configs = V.choices.get_base_mm_configs(device_type)
persistent_mm_configs = V.choices.get_persistent_mm_configs(device_type)
extra_mm_configs = V.choices.get_extra_mm_configs(device_type)
if is_nonzero and use_triton_template(layout):
for config in mm_configs(
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
):
mm_template.maybe_append_choice(
choices,
input_nodes=(mat1, mat2),
@@ -402,7 +396,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
)
if use_triton_tma_template(mat1, mat2):
for config in persistent_mm_configs(
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
):
persistent_tma_mm_template.maybe_append_choice(
choices,
@@ -441,7 +435,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
always_included.append("extern_mm")
num_choices_before_extra_configs = len(choices)
for config in extra_mm_configs(
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
):
mm_template.maybe_append_choice(
choices,
@@ -503,6 +497,8 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
layout,
)
device_type = ir.get_device_type(mat1)
static_shape, is_nonzero = _is_static_problem(layout)
use_cutlass = static_shape and is_nonzero and use_cutlass_template(layout, m, n, k)
@@ -514,9 +510,12 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True
)
int8_mm_configs = V.choices.get_int8_mm_configs(device_type)
if is_nonzero and use_triton_template(layout, enable_int32=True):
for config in int8_mm_configs(
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
):
mm_template.maybe_append_choice(
choices,
@@ -534,6 +533,7 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
@register_lowering(aten.addmm, type_promotion_kind=None)
def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
ordered_kwargs_for_cpp_kernel = ("beta", "alpha")
device_type = ir.get_device_type(mat1)
m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout)
static_shape, is_nonzero = _is_static_problem(layout)
@@ -599,8 +599,13 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
),
)
mm_configs = V.choices.get_base_mm_configs(device_type)
persistent_mm_configs = V.choices.get_persistent_mm_configs(device_type)
if is_nonzero and use_triton_template(layout):
for config in mm_configs(
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
):
mm_template.maybe_append_choice(
choices,
input_nodes=(inp_expanded, mat1, mat2),
@@ -612,7 +617,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
if use_triton_tma_template(mat1, mat2):
for config in persistent_mm_configs(
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
):
persistent_tma_mm_template.maybe_append_choice(
choices,
@@ -751,52 +756,6 @@ def dims_are_int(dims):
return all(isinstance(dim, int) for dim in dims)
def try_heuristic(m, n, k, choices, mat1, mat2, mat2_dtype, layout):
m, n, k = get_size_hints(mat1, mat2, m, n, k)
if not dims_are_int([m, n, k]):
return None
if mat1.dtype != torch.float16:
return None
# only use heuristic if we are running on an A100
# torch.cuda.get_device_capability() >= (8, 0) returns true for A10G
# which does not have enough shared memory for one of the configs
if (
not torch.cuda.get_device_capability() >= (8, 0)
) or get_gpu_shared_memory() != 166912:
return None
if m == 1 and (n % 16 != 0 or k % 16 != 0):
return None
if m <= 16 and n >= 4096 and k >= 4096:
return triton_config(
BLOCK_M=16,
BLOCK_N=64,
BLOCK_K=128,
num_stages=5,
num_warps=4,
)
elif m > 16 and m <= 32 and n >= 4096 and k >= 4096:
return triton_config(
BLOCK_M=32,
BLOCK_N=32,
BLOCK_K=128,
num_stages=5,
num_warps=4,
)
elif m > 32 and m <= 64 and n >= 4096 and k >= 4096:
return triton_config(
BLOCK_M=64,
BLOCK_N=32,
BLOCK_K=128,
num_stages=5,
num_warps=4,
)
return None
def mm_autoheuristic(
mat1,
mat2,


@@ -1,441 +1,22 @@
# mypy: allow-untyped-defs
import functools
import itertools
import logging
from typing import Any
from typing import Any, cast
import sympy
import torch
from torch._inductor.select_algorithm import realize_inputs, SymbolicGridFn
from torch._inductor.virtualized import V
from torch.utils._ordered_set import OrderedSet
from .. import config as inductor_config
from ..codegen.wrapper import PythonWrapperCodegen
from ..ir import ChoiceCaller, Layout
from ..utils import get_num_sms, TMA_DESCRIPTOR_SIZE, use_aten_gemm_kernels
from ..utils import (
get_backend_num_stages,
get_num_sms,
TMA_DESCRIPTOR_SIZE,
use_aten_gemm_kernels,
)
log = logging.getLogger(__name__)
def triton_config(num_stages, num_warps, **kwargs):
from triton import Config # type: ignore[attr-defined]
return Config(kwargs, num_stages=num_stages, num_warps=num_warps)
def build_rocm_gemm_configs(configs):
rocm_num_stages = get_backend_num_stages()
return tuple((c[0], c[1], c[2], rocm_num_stages, c[4]) for c in configs)
def filtered_configs(
m: int,
n: int,
k: int,
configs: Sequence[tuple[int, int, int, int, int]],
has_int8_tensor=False,
scale=1,
exclude=lambda m, n, k: False,
):
"""
Heuristic to shrink configs when they are bigger than the input size
:param scale: scale factor applied to the config values
:param exclude: whether a given config should be excluded
"""
from torch._inductor import config
max_mm_configs = config.test_configs.max_mm_configs
min_block_size = 16
# block_k=16 seems to be causing issues
# see: https://github.com/triton-lang/triton/issues/2156#issuecomment-1695897424
min_block_size_k = 32 if has_int8_tensor else 16
m = max(
next_power_of_2(
V.graph.sizevars.size_hint(
m,
fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type]
)
),
min_block_size,
)
n = max(
next_power_of_2(
V.graph.sizevars.size_hint(
n,
fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type]
)
),
min_block_size,
)
k = max(
next_power_of_2(
V.graph.sizevars.size_hint(
k,
fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type]
)
),
min_block_size_k,
)
used = OrderedSet[tuple[int, ...]]()
for block_m, block_n, block_k, num_stages, num_warps in configs:
# shrink configs for small sizes
block_m = max(min(int(block_m * scale), m), min_block_size)
block_n = max(min(int(block_n * scale), n), min_block_size)
block_k = max(min(int(block_k * scale), k), min_block_size_k)
if exclude(block_m, block_n, block_k):
continue
# each warp computes 16x16 tile = 256
num_warps = min(num_warps, block_m * block_n // 256)
if torch.version.hip:
kpack = 2
for matrix_instr_nonkdim in [0, 16]:
if matrix_instr_nonkdim != 0 and (
block_m % matrix_instr_nonkdim != 0
or block_n % matrix_instr_nonkdim != 0
):
# block_m and block_n must be a multiple of matrix_instr_nonkdim
continue
if (
block_m,
block_n,
block_k,
num_stages,
num_warps,
matrix_instr_nonkdim,
kpack,
) not in used and (
max_mm_configs is None or len(used) < max_mm_configs
):
used.add(
(
block_m,
block_n,
block_k,
num_stages,
num_warps,
matrix_instr_nonkdim,
kpack,
)
)
yield triton_config(
BLOCK_M=block_m,
BLOCK_N=block_n,
BLOCK_K=block_k,
num_stages=num_stages,
num_warps=num_warps,
matrix_instr_nonkdim=matrix_instr_nonkdim,
kpack=kpack,
)
else:
if (block_m, block_n, block_k, num_stages, num_warps, 0) not in used and (
max_mm_configs is None or len(used) < max_mm_configs
):
used.add((block_m, block_n, block_k, num_stages, num_warps, 0))
yield triton_config(
BLOCK_M=block_m,
BLOCK_N=block_n,
BLOCK_K=block_k,
num_stages=num_stages,
num_warps=num_warps,
)
# List of dictionaries to store the kernel configs. Configs that evaluate to true
# will be utilised on the target platform. The configs are as follows:
# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
mm_kernel_configs = (
[
{"config": (32, 32, 16, 1, 2), "cond": True},
{"config": (32, 32, 128, 2, 4), "cond": True},
{"config": (32, 64, 32, 5, 8), "cond": True},
{"config": (64, 32, 32, 5, 8), "cond": True},
{"config": (64, 32, 128, 5, 4), "cond": True},
{"config": (64, 64, 16, 2, 4), "cond": True},
{"config": (64, 64, 32, 2, 4), "cond": True},
{"config": (64, 64, 64, 3, 8), "cond": True},
{"config": (64, 64, 128, 5, 4), "cond": True},
{"config": (64, 128, 32, 3, 4), "cond": True},
{"config": (64, 128, 32, 4, 8), "cond": True},
{"config": (64, 128, 64, 3, 4), "cond": True},
{"config": (64, 128, 128, 4, 4), "cond": True},
{"config": (128, 64, 32, 3, 4), "cond": True},
{"config": (128, 64, 32, 4, 8), "cond": True},
{"config": (128, 128, 32, 2, 8), "cond": True},
{"config": (128, 128, 32, 3, 4), "cond": True},
{"config": (128, 128, 64, 3, 4), "cond": True},
{"config": (128, 128, 64, 5, 8), "cond": True},
{"config": (128, 256, 64, 3, 8), "cond": True},
]
if inductor_config.max_autotune_gemm_search_space != "EXHAUSTIVE"
else [
{"config": (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps), "cond": True}
for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product(
[16, 32, 64, 128, 256], repeat=3
)
for num_stages in [1, 2, 3, 4, 5]
for num_warps in [2, 4, 8]
]
)
# these are only used in tuned_mm when AutoHeuristic is enabled
# the idea is that when AutoHeuristic collects data to learn a heuristic, more configs are autotuned
# when the learned heuristic is used, the learned heuristic reduces the number of configs down to 10
# which saves compilation time (since less configs are autotuned) and potentially increase performance
# because the learned heuristic might predict a config that is not part mm_configs
extra_mm_kernel_configs = [
{"config": (16, 32, 16, 3, 2), "cond": True},
{"config": (16, 32, 32, 4, 2), "cond": True},
{"config": (16, 32, 32, 5, 2), "cond": True},
{"config": (64, 64, 128, 3, 4), "cond": True},
{"config": (128, 64, 32, 2, 2), "cond": True},
{"config": (128, 64, 64, 3, 8), "cond": True},
{"config": (128, 64, 128, 4, 8), "cond": True},
{"config": (128, 128, 32, 4, 4), "cond": True},
{"config": (128, 128, 64, 3, 8), "cond": True},
{"config": (128, 128, 64, 5, 4), "cond": True},
]
int8_mm_kernel_configs = [
{"config": (64, 64, 32, 2, 4), "cond": True},
{"config": (64, 128, 32, 3, 4), "cond": True},
{"config": (128, 64, 32, 3, 4), "cond": True},
{"config": (64, 128, 32, 4, 8), "cond": True},
{"config": (128, 64, 32, 4, 8), "cond": True},
{"config": (64, 32, 32, 5, 8), "cond": True},
{"config": (32, 64, 32, 5, 8), "cond": True},
{"config": (128, 128, 32, 2, 8), "cond": True},
{"config": (64, 64, 64, 3, 8), "cond": True},
# {"config": (32, 32, 128, 2, 4), "cond": True},
# {"config": (64, 64, 16, 2, 4), "cond": True},
# {"config": (32, 32, 16, 1, 2), "cond": True},
{"config": (128, 256, 128, 3, 8), "cond": True},
{"config": (256, 128, 128, 3, 8), "cond": True},
]
# Mixed precision kernel configs for small sizes of m for mm's like (16, 8192) x (8192, 8192).
mixed_mm_kernel_configs_small_m = [
{"config": (16, 128, 256, 3, 4), "cond": True},
{"config": (16, 128, 256, 5, 8), "cond": True},
]
mixed_mm_kernel_configs = (
mm_kernel_configs + mixed_mm_kernel_configs_small_m
if inductor_config.max_autotune_gemm_search_space != "EXHAUSTIVE"
else mm_kernel_configs
)
persistent_mm_kernel_configs = [
{"config": (128, 256, 64, 3, 8), "cond": True},
{"config": (128, 128, 64, 3, 8), "cond": True},
{"config": (128, 128, 128, 3, 8), "cond": True},
{"config": (128, 128, 128, 3, 4), "cond": True},
{"config": (128, 128, 64, 4, 8), "cond": True},
]
scaled_mm_kernel_configs = [
{"config": (128, 256, 32, 3, 8), "cond": True},
{"config": (256, 128, 32, 3, 8), "cond": True},
{"config": (256, 64, 32, 4, 4), "cond": True},
{"config": (64, 256, 32, 4, 4), "cond": True},
{"config": (128, 128, 32, 4, 4), "cond": True},
{"config": (128, 64, 32, 4, 4), "cond": True},
{"config": (64, 128, 32, 4, 4), "cond": True},
{"config": (128, 32, 32, 4, 4), "cond": True},
{"config": (64, 32, 32, 5, 2), "cond": True},
{"config": (256, 128, 128, 3, 8), "cond": True},
{"config": (256, 64, 128, 4, 4), "cond": True},
{"config": (64, 256, 128, 4, 4), "cond": True},
{"config": (128, 128, 128, 4, 4), "cond": True},
{"config": (128, 64, 64, 4, 4), "cond": True},
{"config": (64, 128, 64, 4, 4), "cond": True},
{"config": (128, 32, 64, 4, 4), "cond": True},
{"config": (64, 32, 64, 5, 2), "cond": True},
{"config": (16, 32, 32, 2, 2), "cond": True},
{"config": (16, 64, 32, 2, 2), "cond": True},
{"config": (16, 128, 32, 2, 4), "cond": True},
{"config": (16, 256, 32, 2, 4), "cond": True},
{"config": (16, 32, 64, 2, 2), "cond": True},
{"config": (16, 64, 64, 2, 2), "cond": True},
{"config": (16, 128, 64, 2, 4), "cond": True},
{"config": (16, 256, 64, 2, 4), "cond": True},
{"config": (32, 32, 32, 2, 2), "cond": True},
{"config": (32, 64, 32, 2, 2), "cond": True},
{"config": (32, 128, 32, 2, 4), "cond": True},
{"config": (32, 256, 32, 2, 4), "cond": True},
{"config": (32, 32, 64, 2, 2), "cond": True},
{"config": (32, 64, 64, 2, 2), "cond": True},
{"config": (32, 128, 64, 2, 4), "cond": True},
{"config": (32, 256, 64, 2, 4), "cond": True},
{"config": (16, 32, 32, 3, 2), "cond": True},
{"config": (16, 64, 32, 3, 2), "cond": True},
{"config": (16, 128, 32, 3, 4), "cond": True},
{"config": (16, 256, 32, 3, 4), "cond": True},
{"config": (16, 32, 64, 3, 2), "cond": True},
{"config": (16, 64, 64, 3, 2), "cond": True},
{"config": (16, 128, 64, 3, 4), "cond": True},
{"config": (16, 256, 64, 3, 4), "cond": True},
{"config": (32, 32, 32, 3, 2), "cond": True},
{"config": (32, 64, 32, 3, 2), "cond": True},
{"config": (32, 128, 32, 3, 4), "cond": True},
{"config": (32, 256, 32, 3, 4), "cond": True},
{"config": (32, 32, 64, 3, 2), "cond": True},
{"config": (32, 64, 64, 3, 2), "cond": True},
{"config": (32, 128, 64, 3, 4), "cond": True},
{"config": (32, 256, 64, 3, 4), "cond": True},
{"config": (16, 32, 32, 4, 2), "cond": True},
{"config": (16, 64, 32, 4, 2), "cond": True},
{"config": (16, 128, 32, 4, 4), "cond": True},
{"config": (16, 256, 32, 4, 4), "cond": True},
{"config": (16, 32, 64, 4, 2), "cond": True},
{"config": (16, 64, 64, 4, 2), "cond": True},
{"config": (16, 128, 64, 4, 4), "cond": True},
{"config": (16, 256, 64, 4, 4), "cond": True},
{"config": (32, 32, 32, 4, 2), "cond": True},
{"config": (32, 64, 32, 4, 2), "cond": True},
{"config": (32, 128, 32, 4, 4), "cond": True},
{"config": (32, 256, 32, 4, 4), "cond": True},
{"config": (32, 32, 64, 4, 2), "cond": True},
{"config": (32, 64, 64, 4, 2), "cond": True},
{"config": (32, 128, 64, 4, 4), "cond": True},
{"config": (32, 256, 64, 4, 4), "cond": True},
{"config": (16, 32, 32, 5, 2), "cond": True},
{"config": (16, 64, 32, 5, 2), "cond": True},
{"config": (16, 128, 32, 5, 4), "cond": True},
{"config": (16, 256, 32, 5, 4), "cond": True},
{"config": (16, 32, 64, 5, 2), "cond": True},
{"config": (16, 64, 64, 5, 2), "cond": True},
{"config": (16, 128, 64, 5, 4), "cond": True},
{"config": (16, 256, 64, 5, 4), "cond": True},
{"config": (32, 32, 32, 5, 2), "cond": True},
{"config": (32, 64, 32, 5, 2), "cond": True},
{"config": (32, 128, 32, 5, 4), "cond": True},
{"config": (32, 256, 32, 5, 4), "cond": True},
{"config": (32, 32, 64, 5, 2), "cond": True},
{"config": (32, 64, 64, 5, 2), "cond": True},
{"config": (32, 128, 64, 5, 4), "cond": True},
{"config": (32, 256, 64, 5, 4), "cond": True},
{"config": (16, 32, 32, 6, 2), "cond": True},
{"config": (16, 64, 32, 6, 2), "cond": True},
{"config": (16, 128, 32, 6, 4), "cond": True},
{"config": (16, 256, 32, 6, 4), "cond": True},
{"config": (16, 32, 64, 6, 2), "cond": True},
{"config": (16, 64, 64, 6, 2), "cond": True},
{"config": (16, 128, 64, 6, 4), "cond": True},
{"config": (16, 256, 64, 6, 4), "cond": True},
{"config": (32, 32, 32, 6, 2), "cond": True},
{"config": (32, 64, 32, 6, 2), "cond": True},
{"config": (32, 128, 32, 6, 4), "cond": True},
{"config": (32, 256, 32, 6, 4), "cond": True},
{"config": (32, 32, 64, 6, 2), "cond": True},
{"config": (32, 64, 64, 6, 2), "cond": True},
{"config": (32, 128, 64, 6, 4), "cond": True},
{"config": (32, 256, 64, 6, 4), "cond": True},
]
scaled_persistent_mm_kernel_configs = [
{"config": (128, 128, 64, 3, 8), "cond": True},
{"config": (128, 128, 128, 3, 8), "cond": True},
{"config": (128, 128, 128, 4, 8), "cond": True},
{"config": (128, 128, 128, 4, 4), "cond": True},
{"config": (128, 128, 128, 3, 4), "cond": True},
{"config": (128, 128, 128, 5, 4), "cond": True},
{"config": (128, 128, 128, 5, 8), "cond": True},
{"config": (128, 128, 128, 6, 8), "cond": True},
{"config": (128, 128, 64, 4, 8), "cond": True},
]
# Create filtered list of configs based on cond evaluation
mm_platform_configs = tuple(
cast(tuple[int, int, int, int, int], config["config"])
for config in mm_kernel_configs
if config["cond"]
)
extra_mm_platform_configs = tuple(
cast(tuple[int, int, int, int, int], config["config"])
for config in extra_mm_kernel_configs
if config["cond"]
)
int8_platform_configs = tuple(
cast(tuple[int, int, int, int, int], config["config"])
for config in int8_mm_kernel_configs
if config["cond"]
)
mixed_mm_platform_configs = tuple(
cast(tuple[int, int, int, int, int], config["config"])
for config in mixed_mm_kernel_configs
if config["cond"]
)
persistent_mm_platform_configs = tuple(
cast(tuple[int, int, int, int, int], config["config"])
for config in persistent_mm_kernel_configs
if config["cond"]
)
scaled_mm_platform_configs = tuple(
cast(tuple[int, int, int, int, int], config["config"])
for config in scaled_mm_kernel_configs
if config["cond"]
)
scaled_persistent_mm_platform_configs = tuple(
cast(tuple[int, int, int, int, int], config["config"])
for config in scaled_persistent_mm_kernel_configs
if config["cond"]
)
# On ROCm convert num_stages to improve performance
if torch.version.hip and torch.cuda.is_available():
mm_platform_configs = build_rocm_gemm_configs(mm_platform_configs)
extra_mm_platform_configs = build_rocm_gemm_configs(extra_mm_platform_configs)
int8_platform_configs = build_rocm_gemm_configs(int8_platform_configs)
mixed_mm_platform_configs = build_rocm_gemm_configs(mixed_mm_platform_configs)
scaled_mm_platform_configs = build_rocm_gemm_configs(scaled_mm_platform_configs)
mm_configs = functools.partial(
filtered_configs,
configs=mm_platform_configs,
)
extra_mm_configs = functools.partial(
filtered_configs,
configs=extra_mm_platform_configs,
)
int8_mm_configs = functools.partial(
filtered_configs,
configs=int8_platform_configs,
)
persistent_mm_configs = functools.partial(
filtered_configs,
configs=persistent_mm_platform_configs,
)
scaled_mm_configs = functools.partial(
filtered_configs,
configs=scaled_mm_platform_configs,
)
scaled_persistent_mm_configs = functools.partial(
filtered_configs,
configs=scaled_persistent_mm_platform_configs,
)
def should_fallback_to_aten(choices: list[ChoiceCaller]) -> bool:
if len(choices) == 0 and not use_aten_gemm_kernels():
if inductor_config.autotune_fallback_to_aten:
@@ -553,6 +134,15 @@ def mm_args(
return [m, n, k, layout, mat1, mat2, *others]
def mm_config_kwargs(device, exclude_condition):
if device == "cpu":
return {
"scale": 0.5,
"exclude": exclude_condition,
}
return {}
def addmm_epilogue(dtype, alpha, beta):
def epilogue(acc, bias):
if alpha != 1:


@@ -1,8 +1,8 @@
# mypy: allow-untyped-defs
import functools
import torch
from .. import ir
from ..lowering import lowerings
from ..select_algorithm import (
autotune_select_algorithm,
@@ -112,101 +112,14 @@ mm_plus_mm_template = TritonTemplate(
)
@functools.lru_cache(None)
def mm_configs():
import triton
# List of dictionaries to store the kernel configs. Configs that evaluate to true
# will be utilised on the target platform
mm_triton_configs = [
{
"config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32},
"num_stages": 2,
"num_warps": 4,
"cond": True,
},
{
"config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32},
"num_stages": 3,
"num_warps": 8,
"cond": True,
},
{
"config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32},
"num_stages": 4,
"num_warps": 16,
"cond": True,
},
{
"config": {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32},
"num_stages": 4,
"num_warps": 8,
"cond": True,
},
{
"config": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32},
"num_stages": 4,
"num_warps": 8,
"cond": True,
},
{
"config": {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32},
"num_stages": 1,
"num_warps": 8,
"cond": True,
},
{
"config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64},
"num_stages": 1,
"num_warps": 8,
"cond": True,
},
{
"config": {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 128},
"num_stages": 1,
"num_warps": 8,
"cond": torch.version.hip is None,
},
{
"config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 16},
"num_stages": 2,
"num_warps": 4,
"cond": True,
},
{
"config": {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 16},
"num_stages": 1,
"num_warps": 2,
"cond": True,
},
]
# Filter out configs in which cond evaluates to true
# On ROCm convert num_stages to 1 as pipelining provides no benefit
if torch.version.hip:
filtered_configs = [
triton.Config(c["config"], num_stages=1, num_warps=c["num_warps"])
for c in mm_triton_configs
if c["cond"]
]
else:
filtered_configs = [
triton.Config(
c["config"], num_stages=c["num_stages"], num_warps=c["num_warps"]
)
for c in mm_triton_configs
if c["cond"]
]
return filtered_configs
def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
"""
Computes mm(mat1, mat2) + mm(mat3, mat4)
"""
m1, n1, k1, layout1, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
m2, n2, _, layout2, mat3, mat4 = mm_args(mat3, mat4, layout=layout)
device_type = ir.get_device_type(mat1)
# Optimization is optional, because we can always just not do the fusion
if (
m1 * n1 == 0
@@ -231,6 +144,9 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
if use_aten_gemm_kernels()
else []
)
mm_configs = V.choices.get_mm_plus_mm_configs(device_type)
if use_triton_template(layout1):
for config in mm_configs():
# see https://github.com/openai/triton/issues/1298


@@ -11,7 +11,7 @@ from torch._inductor.codegen.rocm.ck_universal_gemm_template import CKGemmTempla
from torch.utils._triton import has_triton_tma_device
from ..config import triton as triton_config
from ..ir import _IntLike, ChoiceCaller, get_device_type, Layout, StorageBox, TensorBox
from ..lowering import add_layout_constraint, constrain_to_fx_strides, register_lowering
from ..select_algorithm import (
autotune_select_algorithm,
@@ -27,13 +27,12 @@ from ..utils import (
use_ck_gemm_template,
use_triton_template,
)
from ..virtualized import V
from .mm_common import (
_is_static_problem,
mm_args,
mm_grid,
persistent_mm_grid,
scaled_mm_configs,
scaled_persistent_mm_configs,
should_fallback_to_aten,
)
@@ -508,6 +507,7 @@ def tuned_scaled_mm(
m, n, k, layout, mat_a, mat_b = mm_args(
mat_a, mat_b, layout=layout, out_dtype=out_dtype
)
# below is for getting an overview logging info of inductor mms
counters["aten_mm_info"][f"aten._scaled_mm.default_{m}_{n}_{k}"] += 1
log.info(
@@ -520,6 +520,8 @@ def tuned_scaled_mm(
layout,
)
device_type = get_device_type(mat_a)
check_supported_striding(mat_a, mat_b)
scale_a, scale_b = realize_inputs(scale_a, scale_b)
@@ -544,6 +546,11 @@ def tuned_scaled_mm(
_, is_nonzero = _is_static_problem(layout)
scaled_mm_configs = V.choices.get_scaled_mm_configs(device_type)
scaled_persistent_mm_configs = V.choices.get_scaled_persistent_mm_configs(
device_type
)
if is_nonzero and use_triton_template(layout, enable_float8=True):
if use_persistent_tma(k, bias is not None):
for config in scaled_persistent_mm_configs(m, n, k):


@@ -44,7 +44,8 @@ def set_driver_to_gpu():
def get_backend_options():
from triton.runtime import driver
target = driver.active.get_current_target()
backend = triton.compiler.compiler.make_backend(target)
options = backend.parse_options(dict())


@@ -0,0 +1,571 @@
from __future__ import annotations
import itertools
from collections import namedtuple
from functools import partial
from threading import Lock
from typing import Any, Callable, TYPE_CHECKING
from torch.utils._ordered_set import OrderedSet
from . import config
from .utils import get_backend_num_stages
from .virtualized import V
if TYPE_CHECKING:
from collections.abc import Generator, Sequence
from triton import Config as TritonConfig
class BaseConfigSingleton(type):
"""
Thread-safe singleton implementation to be used in the config heuristic subclasses
to ensure heavy __init__ calls are not repeatedly run
"""
_instances: dict[type[Any], Any] = {}
_lock: Lock = Lock()
def __call__(
cls: BaseConfigSingleton, *args: Any, **kwargs: Any
) -> BaseConfigHeuristic:
with cls._lock:
if cls not in cls._instances:
instance = super().__call__()
cls._instances[cls] = instance
return cls._instances[cls]
Config = namedtuple(
"Config", ["block_m", "block_n", "block_k", "num_stages", "num_warps"]
)
class BaseConfigHeuristic(metaclass=BaseConfigSingleton):
"""
Base class for mm_configs; device-specific Triton kernel config heuristics inherit from here
"""
def __init__(self) -> None:
# List of dictionaries to store the kernel configs. Configs that evaluate to true
# will be utilised on the target platform. The configs are as follows:
# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
self.mm_configs = [
Config(32, 32, 16, 1, 2),
Config(32, 32, 128, 2, 4),
Config(32, 64, 32, 5, 8),
Config(64, 32, 32, 5, 8),
Config(64, 32, 128, 5, 4),
Config(64, 64, 16, 2, 4),
Config(64, 64, 32, 2, 4),
Config(64, 64, 64, 3, 8),
Config(64, 64, 128, 5, 4),
Config(64, 128, 32, 3, 4),
Config(64, 128, 32, 4, 8),
Config(64, 128, 64, 3, 4),
Config(64, 128, 128, 4, 4),
Config(128, 64, 32, 3, 4),
Config(128, 64, 32, 4, 8),
Config(128, 128, 32, 2, 8),
Config(128, 128, 32, 3, 4),
Config(128, 128, 64, 3, 4),
Config(128, 128, 64, 5, 8),
]
# Exhaustive search for mm configs
self.exhaustive_configs = [
Config(BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product(
[16, 32, 64, 128, 256], repeat=3
)
for num_stages in [1, 2, 3, 4, 5]
for num_warps in [2, 4, 8]
]
# these are only used in tuned_mm when AutoHeuristic is enabled
# the idea is that when AutoHeuristic collects data to learn a heuristic, more configs are autotuned
# when the learned heuristic is used, the learned heuristic reduces the number of configs down to 10
# which saves compilation time (since less configs are autotuned) and potentially increase performance
# because the learned heuristic might predict a config that is not part of mm_configs
self.extra_mm_configs = [
Config(16, 32, 16, 3, 2),
Config(16, 32, 32, 4, 2),
Config(16, 32, 32, 5, 2),
Config(64, 64, 128, 3, 4),
Config(128, 64, 32, 2, 2),
Config(128, 64, 64, 3, 8),
Config(128, 64, 128, 4, 8),
Config(128, 128, 32, 4, 4),
Config(128, 128, 64, 3, 8),
Config(128, 128, 64, 5, 4),
]
self.int8_mm_configs = [
Config(64, 64, 32, 2, 4),
Config(64, 128, 32, 3, 4),
Config(128, 64, 32, 3, 4),
Config(64, 128, 32, 4, 8),
Config(128, 64, 32, 4, 8),
Config(64, 32, 32, 5, 8),
Config(32, 64, 32, 5, 8),
Config(128, 128, 32, 2, 8),
Config(64, 64, 64, 3, 8),
Config(128, 256, 128, 3, 8),
Config(256, 128, 128, 3, 8),
]
self.mixed_mm_configs = [
Config(16, 128, 256, 3, 4),
Config(16, 128, 256, 5, 8),
]
self.persistent_mm_configs = [
Config(128, 256, 64, 3, 8),
Config(128, 128, 64, 3, 8),
Config(128, 128, 128, 3, 8),
Config(128, 128, 128, 3, 4),
Config(128, 128, 64, 4, 8),
]
self.scaled_mm_configs = [
Config(128, 256, 32, 3, 8),
Config(256, 128, 32, 3, 8),
Config(256, 64, 32, 4, 4),
Config(64, 256, 32, 4, 4),
Config(128, 128, 32, 4, 4),
Config(128, 64, 32, 4, 4),
Config(64, 128, 32, 4, 4),
Config(128, 32, 32, 4, 4),
Config(64, 32, 32, 5, 2),
Config(256, 128, 128, 3, 8),
Config(256, 64, 128, 4, 4),
Config(64, 256, 128, 4, 4),
Config(128, 128, 128, 4, 4),
Config(128, 64, 64, 4, 4),
Config(64, 128, 64, 4, 4),
Config(128, 32, 64, 4, 4),
Config(64, 32, 64, 5, 2),
Config(16, 32, 32, 2, 2),
Config(16, 64, 32, 2, 2),
Config(16, 128, 32, 2, 4),
Config(16, 256, 32, 2, 4),
Config(16, 32, 64, 2, 2),
Config(16, 64, 64, 2, 2),
Config(16, 128, 64, 2, 4),
Config(16, 256, 64, 2, 4),
Config(32, 32, 32, 2, 2),
Config(32, 64, 32, 2, 2),
Config(32, 128, 32, 2, 4),
Config(32, 256, 32, 2, 4),
Config(32, 32, 64, 2, 2),
Config(32, 64, 64, 2, 2),
Config(32, 128, 64, 2, 4),
Config(32, 256, 64, 2, 4),
Config(16, 32, 32, 3, 2),
Config(16, 64, 32, 3, 2),
Config(16, 128, 32, 3, 4),
Config(16, 256, 32, 3, 4),
Config(16, 32, 64, 3, 2),
Config(16, 64, 64, 3, 2),
Config(16, 128, 64, 3, 4),
Config(16, 256, 64, 3, 4),
Config(32, 32, 32, 3, 2),
Config(32, 64, 32, 3, 2),
Config(32, 128, 32, 3, 4),
Config(32, 256, 32, 3, 4),
Config(32, 32, 64, 3, 2),
Config(32, 64, 64, 3, 2),
Config(32, 128, 64, 3, 4),
Config(32, 256, 64, 3, 4),
Config(16, 32, 32, 4, 2),
Config(16, 64, 32, 4, 2),
Config(16, 128, 32, 4, 4),
Config(16, 256, 32, 4, 4),
Config(16, 32, 64, 4, 2),
Config(16, 64, 64, 4, 2),
Config(16, 128, 64, 4, 4),
Config(16, 256, 64, 4, 4),
Config(32, 32, 32, 4, 2),
Config(32, 64, 32, 4, 2),
Config(32, 128, 32, 4, 4),
Config(32, 256, 32, 4, 4),
Config(32, 32, 64, 4, 2),
Config(32, 64, 64, 4, 2),
Config(32, 128, 64, 4, 4),
Config(32, 256, 64, 4, 4),
Config(16, 32, 32, 5, 2),
Config(16, 64, 32, 5, 2),
Config(16, 128, 32, 5, 4),
Config(16, 256, 32, 5, 4),
Config(16, 32, 64, 5, 2),
Config(16, 64, 64, 5, 2),
Config(16, 128, 64, 5, 4),
Config(16, 256, 64, 5, 4),
Config(32, 32, 32, 5, 2),
Config(32, 64, 32, 5, 2),
Config(32, 128, 32, 5, 4),
Config(32, 256, 32, 5, 4),
Config(32, 32, 64, 5, 2),
Config(32, 64, 64, 5, 2),
Config(32, 128, 64, 5, 4),
Config(32, 256, 64, 5, 4),
Config(16, 32, 32, 6, 2),
Config(16, 64, 32, 6, 2),
Config(16, 128, 32, 6, 4),
Config(16, 256, 32, 6, 4),
Config(16, 32, 64, 6, 2),
Config(16, 64, 64, 6, 2),
Config(16, 128, 64, 6, 4),
Config(16, 256, 64, 6, 4),
Config(32, 32, 32, 6, 2),
Config(32, 64, 32, 6, 2),
Config(32, 128, 32, 6, 4),
Config(32, 256, 32, 6, 4),
Config(32, 32, 64, 6, 2),
Config(32, 64, 64, 6, 2),
Config(32, 128, 64, 6, 4),
Config(32, 256, 64, 6, 4),
]
self.scaled_persistent_mm_configs = [
Config(128, 128, 64, 3, 8),
Config(128, 128, 128, 3, 8),
Config(128, 128, 128, 4, 8),
Config(128, 128, 128, 4, 4),
Config(128, 128, 128, 3, 4),
Config(128, 128, 128, 5, 4),
Config(128, 128, 128, 5, 8),
Config(128, 128, 128, 6, 8),
Config(128, 128, 64, 4, 8),
]
# TODO: Unify with other gemm patterns, mm_plus_mm currently follows
# slightly different pattern than rest
self.mm_plus_mm_configs = [
Config(64, 64, 32, 2, 4),
Config(64, 64, 32, 3, 8),
Config(64, 64, 32, 4, 16),
Config(64, 32, 32, 4, 8),
Config(32, 64, 32, 4, 8),
Config(128, 128, 32, 1, 8),
Config(64, 64, 64, 1, 8),
Config(32, 32, 128, 1, 8),
Config(64, 64, 16, 2, 4),
Config(32, 32, 16, 1, 2),
]
self.conv_configs = [
Config(64, 256, 16, 2, 4),
Config(256, 64, 16, 2, 4),
Config(1024, 16, 16, 1, 8),
Config(128, 128, 32, 2, 8),
Config(64, 64, 32, 2, 4),
Config(64, 256, 32, 2, 8),
Config(256, 64, 32, 2, 8),
]
def _finalize_mm_configs(
self,
configs: list[Config],
) -> Generator[TritonConfig, None, None]:
"""
Finalizes configs after scaling, applying additional constraints.
"""
used = OrderedSet[Config]()
max_mm_configs = config.test_configs.max_mm_configs
for block_m, block_n, block_k, num_stages, num_warps in configs:
# Each warp computes a 16x16 tile = 256 elements
num_warps = min(num_warps, block_m * block_n // 256)
if (
Config(block_m, block_n, block_k, num_stages, num_warps)
) not in used and (max_mm_configs is None or len(used) < max_mm_configs):
used.add(Config(block_m, block_n, block_k, num_stages, num_warps))
yield self.triton_config(
BLOCK_M=block_m,
BLOCK_N=block_n,
BLOCK_K=block_k,
num_stages=num_stages,
num_warps=num_warps,
)
def _scale_mm_configs(
self,
m: int,
n: int,
k: int,
configs: Sequence[Config],
scale: float,
has_int8_tensor: bool,
exclude: Callable[[int, int, int], bool],
) -> list[Config]:
"""
Scales and filters matrix multiplication configs based on input size.
"""
from .runtime.runtime_utils import next_power_of_2
min_block_size = 16
min_block_size_k = 32 if has_int8_tensor else 16
m = max(
next_power_of_2(
V.graph.sizevars.size_hint(
m,
fallback=config.unbacked_symint_fallback, # type: ignore[arg-type]
)
),
min_block_size,
)
n = max(
next_power_of_2(
V.graph.sizevars.size_hint(
n,
fallback=config.unbacked_symint_fallback, # type: ignore[arg-type]
)
),
min_block_size,
)
k = max(
next_power_of_2(
V.graph.sizevars.size_hint(
k,
fallback=config.unbacked_symint_fallback, # type: ignore[arg-type]
)
),
min_block_size_k,
)
scaled_configs = []
for c in configs:
scaled_config = c._replace(
block_m=max(min(int(c.block_m * scale), m), min_block_size),
block_n=max(min(int(c.block_n * scale), n), min_block_size),
block_k=max(min(int(c.block_k * scale), k), min_block_size_k),
)
if not exclude(
scaled_config.block_m, scaled_config.block_n, scaled_config.block_k
):
scaled_configs.append(scaled_config)
return scaled_configs
def preprocess_mm_configs(
self,
m: int,
n: int,
k: int,
configs: Sequence[Config],
has_int8_tensor: bool = False,
scale: int = 1,
exclude: Callable[[int, int, int], bool] = lambda m, n, k: False,
) -> Generator[TritonConfig, None, None]:
scaled_configs = self._scale_mm_configs(
m, n, k, configs, scale, has_int8_tensor, exclude
)
return self._finalize_mm_configs(scaled_configs)
def triton_config(
self, num_stages: int, num_warps: int, **kwargs: Any
) -> TritonConfig:
from triton import Config as TritonConfig # type: ignore[attr-defined]
return TritonConfig(kwargs, num_stages=num_stages, num_warps=num_warps)
def get_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
return partial(self.preprocess_mm_configs, configs=self.mm_configs)
def get_exhaustive_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
return partial(self.preprocess_mm_configs, configs=self.exhaustive_configs)
def get_extra_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
return partial(self.preprocess_mm_configs, configs=self.extra_mm_configs)
def get_int8_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
return partial(self.preprocess_mm_configs, configs=self.int8_mm_configs)
def get_mixed_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
mm_configs = (
self.mm_configs + self.mixed_mm_configs
if config.max_autotune_gemm_search_space == "EXHAUSTIVE"
else self.mm_configs
)
return partial(self.preprocess_mm_configs, configs=mm_configs)
def get_persistent_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
return partial(self.preprocess_mm_configs, configs=self.persistent_mm_configs)
def get_scaled_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
return partial(self.preprocess_mm_configs, configs=self.scaled_mm_configs)
def get_scaled_persistent_mm_configs(
self,
) -> partial[Generator[TritonConfig, None, None]]:
return partial(
self.preprocess_mm_configs, configs=self.scaled_persistent_mm_configs
)
def get_mm_plus_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
return partial(self._finalize_mm_configs, configs=self.mm_plus_mm_configs)
def get_conv_configs(self) -> partial[Generator[TritonConfig, None, None]]:
return partial(self.preprocess_mm_configs, configs=self.conv_configs)
class CPUConfigHeuristic(BaseConfigHeuristic):
pass
class CUDAConfigHeuristic(BaseConfigHeuristic):
pass
class ROCmConfigHeuristic(BaseConfigHeuristic):
def __init__(self) -> None:
super().__init__()
self.default_num_stages = get_backend_num_stages()
# Exhaustive search for mm configs
self.exhaustive_configs = [
Config(BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product(
[16, 32, 64, 128, 256], repeat=3
)
for num_stages in [1, self.default_num_stages]
for num_warps in [4, 8]
]
def _filter_configs(
self, configs: list[Config], new_num_stages: int
) -> list[Config]:
filtered_configs = [
c._replace(num_stages=self.default_num_stages) for c in configs
]
return filtered_configs
def _finalize_mm_configs(
self,
configs: list[Config],
) -> Generator[TritonConfig, None, None]:
used = OrderedSet[tuple[Config, int, int]]()
max_mm_configs = config.test_configs.max_mm_configs
for block_m, block_n, block_k, num_stages, num_warps in configs:
# each warp computes 16x16 tile = 256
num_warps = min(num_warps, block_m * block_n // 256)
kpack = 2
for matrix_instr_nonkdim in [0, 16]:
if matrix_instr_nonkdim != 0 and (
block_m % matrix_instr_nonkdim != 0
or block_n % matrix_instr_nonkdim != 0
):
# block_m and block_n must be a multiple of matrix_instr_nonkdim
continue
if (
Config(
block_m,
block_n,
block_k,
num_stages,
num_warps,
),
matrix_instr_nonkdim,
kpack,
) not in used and (
max_mm_configs is None or len(used) < max_mm_configs
):
used.add(
(
Config(
block_m,
block_n,
block_k,
num_stages,
num_warps,
),
matrix_instr_nonkdim,
kpack,
)
)
yield self.triton_config(
BLOCK_M=block_m,
BLOCK_N=block_n,
BLOCK_K=block_k,
num_stages=num_stages,
num_warps=num_warps,
matrix_instr_nonkdim=matrix_instr_nonkdim,
kpack=kpack,
)
def get_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
filtered_configs = self._filter_configs(
self.mm_configs, self.default_num_stages
)
return partial(self.preprocess_mm_configs, configs=filtered_configs)
def get_exhaustive_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
filtered_configs = self._filter_configs(
self.exhaustive_configs, self.default_num_stages
)
return partial(self.preprocess_mm_configs, configs=filtered_configs)
def get_extra_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
filtered_configs = self._filter_configs(
self.extra_mm_configs, self.default_num_stages
)
return partial(self.preprocess_mm_configs, configs=filtered_configs)
def get_int8_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
filtered_configs = self._filter_configs(
self.int8_mm_configs, self.default_num_stages
)
return partial(self.preprocess_mm_configs, configs=filtered_configs)
def get_mixed_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
mm_configs = (
self.mm_configs + self.mixed_mm_configs
if config.max_autotune_gemm_search_space == "EXHAUSTIVE"
else self.mm_configs
)
filtered_configs = self._filter_configs(mm_configs, self.default_num_stages)
return partial(self.preprocess_mm_configs, configs=filtered_configs)
def get_persistent_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
filtered_configs = self._filter_configs(
self.persistent_mm_configs, self.default_num_stages
)
return partial(self.preprocess_mm_configs, configs=filtered_configs)
def get_scaled_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
filtered_configs = self._filter_configs(
self.scaled_mm_configs, self.default_num_stages
)
return partial(self.preprocess_mm_configs, configs=filtered_configs)
def get_scaled_persistent_mm_configs(
self,
) -> partial[Generator[TritonConfig, None, None]]:
filtered_configs = self._filter_configs(
self.scaled_persistent_mm_configs, self.default_num_stages
)
return partial(self.preprocess_mm_configs, configs=filtered_configs)
def get_mm_plus_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
filtered_configs = self._filter_configs(self.mm_plus_mm_configs, 1)
return partial(self._finalize_mm_configs, configs=filtered_configs)
def get_conv_configs(self) -> partial[Generator[TritonConfig, None, None]]:
filtered_configs = self._filter_configs(
self.conv_configs, self.default_num_stages
)
return partial(self.preprocess_mm_configs, configs=filtered_configs)
class XPUConfigHeuristic(BaseConfigHeuristic):
pass
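
To make the data flow in the new template_heuristics module concrete, here is a hedged sketch of how the pieces fit together. It assumes Triton is installed; _finalize_mm_configs is an internal helper and is called directly only for illustration, since the public accessors above return partials over it.

# Illustrative sketch (not part of the commit).
from torch._inductor.template_heuristics import (
    BaseConfigHeuristic,
    CUDAConfigHeuristic,
)

# The metaclass makes each heuristic class a singleton, so the heavy
# __init__ that builds the config lists runs once per class.
assert CUDAConfigHeuristic() is CUDAConfigHeuristic()

heuristic = BaseConfigHeuristic()

# _finalize_mm_configs caps num_warps at one warp per 16x16 output tile,
# deduplicates, and yields triton.Config objects ready for autotuning.
configs = list(heuristic._finalize_mm_configs(heuristic.mm_plus_mm_configs))
print(len(configs), configs[0].kwargs, configs[0].num_warps)
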