Reland "Introduce new template heuristic for triton autotune configs" (#147452)

This change was reverted in https://github.com/pytorch/pytorch/pull/147388 for regressing an internal workload.

I have removed the additional ir.device_type calls in mm_scaled and unpack_mixed_mm.py which could be contributing to the additional compile time.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147452
Approved by: https://github.com/jansel
This commit is contained in:
Jack Taylor 2025-03-26 15:47:06 +00:00 committed by PyTorch MergeBot
parent 7336b76bcc
commit 32299e5f9a
9 changed files with 739 additions and 640 deletions

View File

@ -1,20 +1,33 @@
from __future__ import annotations
import typing
from typing import Any, TYPE_CHECKING
from typing import Any, Optional, TYPE_CHECKING
import sympy
import torch
from . import config
from .codecache import write_text
from .metrics import get_metric_table, is_metric_table_enabled
from .runtime.hints import DeviceProperties, ReductionHint
from .scheduler import BaseSchedulerNode, Scheduler, WhyNoFuse
from .template_heuristics import (
BaseConfigHeuristic,
CPUConfigHeuristic,
CUDAConfigHeuristic,
ROCmConfigHeuristic,
XPUConfigHeuristic,
)
from .virtualized import V
if TYPE_CHECKING:
import torch
from collections.abc import Generator
from functools import partial
from triton import Config as TritonConfig
from torch.utils._ordered_set import OrderedSet
from .codegen.simd_kernel_features import SIMDKernelFeatures
@ -40,6 +53,80 @@ class InductorChoices:
torch._inductor.virtualized.V.set_choices_handler(MyHeuristics())
"""
def get_config_heuristics(
self, device_type: Optional[str] = "cuda"
) -> BaseConfigHeuristic:
if device_type == "cuda":
if torch.version.hip is None:
return CUDAConfigHeuristic()
else:
return ROCmConfigHeuristic()
elif device_type == "xpu":
return XPUConfigHeuristic()
elif device_type == "cpu":
return CPUConfigHeuristic()
else:
return BaseConfigHeuristic()
# GEMM configs
def get_base_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
if config.max_autotune_gemm_search_space != "EXHAUSTIVE":
return mm_heuristics.get_mm_configs()
else:
return mm_heuristics.get_exhaustive_mm_configs()
def get_extra_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_extra_mm_configs()
def get_int8_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_int8_mm_configs()
def get_mixed_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_mixed_mm_configs()
def get_persistent_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_persistent_mm_configs()
def get_scaled_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_scaled_mm_configs()
def get_scaled_persistent_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_scaled_persistent_mm_configs()
def get_mm_plus_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_mm_plus_mm_configs()
# Conv configs
def get_conv_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
conv_heuristics = self.get_config_heuristics(device_type)
return conv_heuristics.get_conv_configs()
def triton_kernel_kwargs(
self,
kernel_cls: type[TritonKernel],

View File

@ -24,7 +24,7 @@ from .mm_common import (
_is_static_problem,
addmm_epilogue,
mm_args,
mm_configs,
mm_config_kwargs,
mm_options,
should_fallback_to_aten,
)
@ -46,12 +46,6 @@ def _is_large_block_for_cpu(m, n, k):
return m * n > 2**12
def bmm_configs(m, n, k, *, device_type):
if device_type == "cpu":
return mm_configs(m, n, k, scale=0.5, exclude=_is_large_block_for_cpu)
return mm_configs(m, n, k)
bmm_template = TritonTemplate(
name="bmm",
grid=bmm_grid,
@ -184,8 +178,14 @@ def tuned_bmm(mat1, mat2, *, layout=None):
# options to tune from
choices = [aten_bmm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else []
device_type = ir.get_device_type(mat1)
bmm_configs = V.choices.get_base_mm_configs(device_type)
if use_triton_template(layout):
for config in bmm_configs(m, n, k, device_type=ir.get_device_type(mat1)):
for config in bmm_configs(
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
):
bmm_template.maybe_append_choice(
choices,
input_nodes=(mat1, mat2),
@ -239,8 +239,14 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
if use_aten_gemm_kernels()
else []
)
device_type = ir.get_device_type(mat1)
bmm_configs = V.choices.get_base_mm_configs(device_type)
if use_triton_template(layout):
for config in bmm_configs(m, n, k, device_type=ir.get_device_type(mat1)):
for config in bmm_configs(
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
):
bmm_template.maybe_append_choice(
choices,
input_nodes=(inp, mat1, mat2),

View File

@ -2,7 +2,7 @@
from __future__ import annotations
import logging
from typing import cast, Optional, TYPE_CHECKING, TypedDict
from typing import Optional, TYPE_CHECKING, TypedDict
import torch
from torch._inductor.codegen.rocm.ck_conv_template import CKGroupedConvFwdTemplate
@ -29,7 +29,7 @@ from ..utils import (
use_triton_template,
)
from ..virtualized import V
from .mm_common import build_rocm_gemm_configs, filtered_configs
from .mm_common import mm_config_kwargs
if TYPE_CHECKING:
@ -61,31 +61,6 @@ def conv3d_grid(n, c, d, h, w, meta, *, cdiv):
)
# List of dictionaries to store the kernel configs. Configs that evaluate to true
# will be utilised on the target platform
kernel_configs = [
# "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps"
{"config": (64, 256, 16, 2, 4), "cond": True},
{"config": (256, 64, 16, 2, 4), "cond": True},
{"config": (1024, 16, 16, 1, 8), "cond": True},
{"config": (128, 128, 32, 2, 8), "cond": True},
{"config": (64, 64, 32, 2, 4), "cond": True},
{"config": (64, 256, 32, 2, 8), "cond": True},
{"config": (256, 64, 32, 2, 8), "cond": True},
]
# Create filtered list of configs based on conv
platform_configs = tuple(
cast(tuple[int, int, int, int, int], config["config"])
for config in kernel_configs
if config["cond"]
)
# On ROCm convert num_stages to 1 as pipelining provides no benefit
if torch.version.hip and torch.cuda.is_available():
platform_configs = build_rocm_gemm_configs(platform_configs)
def _is_large_block_for_cpu(m, n, k):
# Thresholds are experimentally determined to reduce Triton CPU compile times
if m > 256 or n > 256 or k > 256:
@ -93,19 +68,6 @@ def _is_large_block_for_cpu(m, n, k):
return m * n * k > 2**17
def conv_configs(m, n, k, *, device_type, **kwargs):
if device_type == "cpu":
return filtered_configs(
m,
n,
k,
configs=platform_configs,
scale=0.5,
exclude=_is_large_block_for_cpu,
)
return filtered_configs(m, n, k, configs=platform_configs)
LOOP_BODY_2D = """
idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H
idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W
@ -497,6 +459,8 @@ def convolution(
"groups": groups,
}
device_type = ir.get_device_type(x)
if len(x.get_size()) == len(weight.get_size()) - 1:
# add batch dimension to simplify rest of function
return L[aten.squeeze](
@ -511,11 +475,7 @@ def convolution(
# Always convert conv1D to 2D for Intel GPU.
# Only conv2D can be converted to channel last layout,
# which have much better performance.
if (
len(x.get_size()) == 3
and len(kernel_shape) == 1
and ir.get_device_type(x) == "xpu"
):
if len(x.get_size()) == 3 and len(kernel_shape) == 1 and device_type == "xpu":
kwargs.update(
{
"stride": (1,) + stride,
@ -564,7 +524,7 @@ def convolution(
):
return convert_1x1_conv_to_mm(x, weight, bias)
if bias is not None and ir.get_device_type(x) != "cpu":
if bias is not None and device_type != "cpu":
# peel off the bias, cudnn is slower with it
result = convolution(x, weight, None, **kwargs)
return L[aten.add](
@ -639,11 +599,13 @@ def convolution(
):
choices.append(aten_conv1x1_via_mm.bind(args, layout))
conv_configs = V.choices.get_conv_configs(device_type)
for cfg in conv_configs(
sympy_product([x.get_size()[0], *x.get_size()[2:]]),
out_chan,
in_chan,
device_type=ir.get_device_type(x),
**mm_config_kwargs(device_type, _is_large_block_for_cpu),
):
if ndim == 2:
conv2d_template.maybe_append_choice(

View File

@ -28,7 +28,6 @@ from ..select_algorithm import (
TritonTemplate,
)
from ..utils import (
get_gpu_shared_memory,
get_tma_workspace_arg,
use_aten_gemm_kernels,
use_ck_gemm_template,
@ -41,17 +40,13 @@ from ..utils import (
from .mm_common import (
_is_static_problem,
addmm_epilogue,
extra_mm_configs,
int8_mm_configs,
mm_args,
mm_configs,
mm_config_kwargs,
mm_grid,
mm_options,
persistent_mm_configs,
persistent_mm_grid,
persistent_mm_options,
should_fallback_to_aten,
triton_config,
)
@ -341,15 +336,6 @@ def _is_large_block_for_cpu(m, n, k):
return m * n > 2**13
def mm_config_kwargs(device):
if device == "cpu":
return {
"scale": 0.5,
"exclude": _is_large_block_for_cpu,
}
return {}
def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1):
"""
Giving torch.addmm a 1D tensor calls a different (faster) cublasLt
@ -367,6 +353,7 @@ aten_bias_addmm = ExternKernelChoice(bias_addmm, None)
@register_lowering(aten.mm, type_promotion_kind=None)
def tuned_mm(mat1, mat2, *, layout=None):
m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
device_type = ir.get_device_type(mat1)
name = "mm"
# below is for getting an overview logging info of inductor mms
@ -392,8 +379,15 @@ def tuned_mm(mat1, mat2, *, layout=None):
[aten_mm.bind((mat1, mat2), aten_layout)] if use_aten_gemm_kernels() else []
)
static_shape, is_nonzero = _is_static_problem(layout)
mm_configs = V.choices.get_base_mm_configs(device_type)
persistent_mm_configs = V.choices.get_persistent_mm_configs(device_type)
extra_mm_configs = V.choices.get_extra_mm_configs(device_type)
if is_nonzero and use_triton_template(layout):
for config in mm_configs(m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))):
for config in mm_configs(
m, n, k, *mm_config_kwargs(device_type, _is_large_block_for_cpu)
):
mm_template.maybe_append_choice(
choices,
input_nodes=(mat1, mat2),
@ -402,7 +396,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
)
if use_triton_tma_template(mat1, mat2):
for config in persistent_mm_configs(
m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
):
persistent_tma_mm_template.maybe_append_choice(
choices,
@ -441,7 +435,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
always_included.append("extern_mm")
num_choices_before_extra_configs = len(choices)
for config in extra_mm_configs(
m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
):
mm_template.maybe_append_choice(
choices,
@ -503,6 +497,8 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
layout,
)
device_type = ir.get_device_type(mat1)
static_shape, is_nonzero = _is_static_problem(layout)
use_cutlass = static_shape and is_nonzero and use_cutlass_template(layout, m, n, k)
@ -514,9 +510,12 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True
)
int8_mm_configs = V.choices.get_int8_mm_configs(device_type)
if is_nonzero and use_triton_template(layout, enable_int32=True):
for config in int8_mm_configs(
m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
):
mm_template.maybe_append_choice(
choices,
@ -534,6 +533,7 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
@register_lowering(aten.addmm, type_promotion_kind=None)
def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
ordered_kwargs_for_cpp_kernel = ("beta", "alpha")
device_type = ir.get_device_type(mat1)
m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout)
static_shape, is_nonzero = _is_static_problem(layout)
@ -599,8 +599,13 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
),
)
mm_configs = V.choices.get_base_mm_configs(device_type)
persistent_mm_configs = V.choices.get_persistent_mm_configs(device_type)
if is_nonzero and use_triton_template(layout):
for config in mm_configs(m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))):
for config in mm_configs(
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
):
mm_template.maybe_append_choice(
choices,
input_nodes=(inp_expanded, mat1, mat2),
@ -612,7 +617,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
if use_triton_tma_template(mat1, mat2):
for config in persistent_mm_configs(
m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
):
persistent_tma_mm_template.maybe_append_choice(
choices,
@ -751,52 +756,6 @@ def dims_are_int(dims):
return all(isinstance(dim, int) for dim in dims)
def try_heuristic(m, n, k, choices, mat1, mat2, mat2_dtype, layout):
m, n, k = get_size_hints(mat1, mat2, m, n, k)
if not dims_are_int([m, n, k]):
return None
if mat1.dtype != torch.float16:
return None
# only use heuristic if we are running on an A100
# torch.cuda.get_device_capability() >= (8, 0) returns true for A10G
# which does not have enough shared memory for one of the configs
if (
not torch.cuda.get_device_capability() >= (8, 0)
) or get_gpu_shared_memory() != 166912:
return None
if m == 1 and (n % 16 != 0 or k % 16 != 0):
return None
if m <= 16 and n >= 4096 and k >= 4096:
return triton_config(
BLOCK_M=16,
BLOCK_N=64,
BLOCK_K=128,
num_stages=5,
num_warps=4,
)
elif m > 16 and m <= 32 and n >= 4096 and k >= 4096:
return triton_config(
BLOCK_M=32,
BLOCK_N=32,
BLOCK_K=128,
num_stages=5,
num_warps=4,
)
elif m > 32 and m <= 64 and n >= 4096 and k >= 4096:
return triton_config(
BLOCK_M=64,
BLOCK_N=32,
BLOCK_K=128,
num_stages=5,
num_warps=4,
)
return None
def mm_autoheuristic(
mat1,
mat2,

View File

@ -1,441 +1,22 @@
# mypy: allow-untyped-defs
import functools
import itertools
import logging
from collections.abc import Sequence
from typing import Any, cast
from typing import Any
import sympy
import torch
from torch._inductor.select_algorithm import realize_inputs, SymbolicGridFn
from torch._inductor.virtualized import V
from torch.utils._ordered_set import OrderedSet
from .. import config as inductor_config
from ..codegen.wrapper import PythonWrapperCodegen
from ..ir import ChoiceCaller, Layout
from ..runtime.runtime_utils import next_power_of_2
from ..utils import (
get_backend_num_stages,
get_num_sms,
TMA_DESCRIPTOR_SIZE,
use_aten_gemm_kernels,
)
from ..utils import get_num_sms, TMA_DESCRIPTOR_SIZE, use_aten_gemm_kernels
log = logging.getLogger(__name__)
def triton_config(num_stages, num_warps, **kwargs):
from triton import Config # type: ignore[attr-defined]
return Config(kwargs, num_stages=num_stages, num_warps=num_warps)
def build_rocm_gemm_configs(configs):
rocm_num_stages = get_backend_num_stages()
return tuple((c[0], c[1], c[2], rocm_num_stages, c[4]) for c in configs)
def filtered_configs(
m: int,
n: int,
k: int,
configs: Sequence[tuple[int, int, int, int, int]],
has_int8_tensor=False,
scale=1,
exclude=lambda m, n, k: False,
):
"""
Heuristic to shrink configs when they are bigger than the input size
:param scale: scale factor applied to the config values
:param exclude: whether a given config should be excluded
"""
from torch._inductor import config
max_mm_configs = config.test_configs.max_mm_configs
min_block_size = 16
# block_k=16 seems to be causing issues
# see: https://github.com/triton-lang/triton/issues/2156#issuecomment-1695897424
min_block_size_k = 32 if has_int8_tensor else 16
m = max(
next_power_of_2(
V.graph.sizevars.size_hint(
m,
fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type]
)
),
min_block_size,
)
n = max(
next_power_of_2(
V.graph.sizevars.size_hint(
n,
fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type]
)
),
min_block_size,
)
k = max(
next_power_of_2(
V.graph.sizevars.size_hint(
k,
fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type]
)
),
min_block_size_k,
)
used = OrderedSet[tuple[int, ...]]()
for block_m, block_n, block_k, num_stages, num_warps in configs:
# shrink configs for small sizes
block_m = max(min(int(block_m * scale), m), min_block_size)
block_n = max(min(int(block_n * scale), n), min_block_size)
block_k = max(min(int(block_k * scale), k), min_block_size_k)
if exclude(block_m, block_n, block_k):
continue
# each warp computes 16x16 tile = 256
num_warps = min(num_warps, block_m * block_n // 256)
if torch.version.hip:
kpack = 2
for matrix_instr_nonkdim in [0, 16]:
if matrix_instr_nonkdim != 0 and (
block_m % matrix_instr_nonkdim != 0
or block_n % matrix_instr_nonkdim != 0
):
# block_m and block_n must be a multiple of matrix_instr_nonkdim
continue
if (
block_m,
block_n,
block_k,
num_stages,
num_warps,
matrix_instr_nonkdim,
kpack,
) not in used and (
max_mm_configs is None or len(used) < max_mm_configs
):
used.add(
(
block_m,
block_n,
block_k,
num_stages,
num_warps,
matrix_instr_nonkdim,
kpack,
)
)
yield triton_config(
BLOCK_M=block_m,
BLOCK_N=block_n,
BLOCK_K=block_k,
num_stages=num_stages,
num_warps=num_warps,
matrix_instr_nonkdim=matrix_instr_nonkdim,
kpack=kpack,
)
else:
if (block_m, block_n, block_k, num_stages, num_warps, 0) not in used and (
max_mm_configs is None or len(used) < max_mm_configs
):
used.add((block_m, block_n, block_k, num_stages, num_warps, 0))
yield triton_config(
BLOCK_M=block_m,
BLOCK_N=block_n,
BLOCK_K=block_k,
num_stages=num_stages,
num_warps=num_warps,
)
# List of dictionaries to store the kernel configs. Configs that evaluate to true
# will be utilised on the target platform. The configs are as follows:
# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
mm_kernel_configs = (
[
{"config": (32, 32, 16, 1, 2), "cond": True},
{"config": (32, 32, 128, 2, 4), "cond": True},
{"config": (32, 64, 32, 5, 8), "cond": True},
{"config": (64, 32, 32, 5, 8), "cond": True},
{"config": (64, 32, 128, 5, 4), "cond": True},
{"config": (64, 64, 16, 2, 4), "cond": True},
{"config": (64, 64, 32, 2, 4), "cond": True},
{"config": (64, 64, 64, 3, 8), "cond": True},
{"config": (64, 64, 128, 5, 4), "cond": True},
{"config": (64, 128, 32, 3, 4), "cond": True},
{"config": (64, 128, 32, 4, 8), "cond": True},
{"config": (64, 128, 64, 3, 4), "cond": True},
{"config": (64, 128, 128, 4, 4), "cond": True},
{"config": (128, 64, 32, 3, 4), "cond": True},
{"config": (128, 64, 32, 4, 8), "cond": True},
{"config": (128, 128, 32, 2, 8), "cond": True},
{"config": (128, 128, 32, 3, 4), "cond": True},
{"config": (128, 128, 64, 3, 4), "cond": True},
{"config": (128, 128, 64, 5, 8), "cond": True},
{"config": (128, 256, 64, 3, 8), "cond": True},
]
if inductor_config.max_autotune_gemm_search_space != "EXHAUSTIVE"
else [
{"config": (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps), "cond": True}
for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product(
[16, 32, 64, 128, 256], repeat=3
)
for num_stages in [1, 2, 3, 4, 5]
for num_warps in [2, 4, 8]
]
)
# these are only used in tuned_mm when AutoHeuristic is enabled
# the idea is that when AutoHeuristic collects data to learn a heuristic, more configs are autotuned
# when the learned heuristic is used, the learned heuristic reduces the number of configs down to 10
# which saves compilation time (since less configs are autotuned) and potentially increase performance
# because the learned heuristic might predict a config that is not part mm_configs
extra_mm_kernel_configs = [
{"config": (16, 32, 16, 3, 2), "cond": True},
{"config": (16, 32, 32, 4, 2), "cond": True},
{"config": (16, 32, 32, 5, 2), "cond": True},
{"config": (64, 64, 128, 3, 4), "cond": True},
{"config": (128, 64, 32, 2, 2), "cond": True},
{"config": (128, 64, 64, 3, 8), "cond": True},
{"config": (128, 64, 128, 4, 8), "cond": True},
{"config": (128, 128, 32, 4, 4), "cond": True},
{"config": (128, 128, 64, 3, 8), "cond": True},
{"config": (128, 128, 64, 5, 4), "cond": True},
]
int8_mm_kernel_configs = [
{"config": (64, 64, 32, 2, 4), "cond": True},
{"config": (64, 128, 32, 3, 4), "cond": True},
{"config": (128, 64, 32, 3, 4), "cond": True},
{"config": (64, 128, 32, 4, 8), "cond": True},
{"config": (128, 64, 32, 4, 8), "cond": True},
{"config": (64, 32, 32, 5, 8), "cond": True},
{"config": (32, 64, 32, 5, 8), "cond": True},
{"config": (128, 128, 32, 2, 8), "cond": True},
{"config": (64, 64, 64, 3, 8), "cond": True},
# {"config": (32, 32, 128, 2, 4), "cond": True},
# {"config": (64, 64, 16, 2, 4), "cond": True},
# {"config": (32, 32, 16, 1, 2), "cond": True},
{"config": (128, 256, 128, 3, 8), "cond": True},
{"config": (256, 128, 128, 3, 8), "cond": True},
]
# Mixed precision kernel configs for small sizes of m for mm's like (16, 8192) x (8192, 8192).
mixed_mm_kernel_configs_small_m = [
{"config": (16, 128, 256, 3, 4), "cond": True},
{"config": (16, 128, 256, 5, 8), "cond": True},
]
mixed_mm_kernel_configs = (
mm_kernel_configs + mixed_mm_kernel_configs_small_m
if inductor_config.max_autotune_gemm_search_space != "EXHAUSTIVE"
else mm_kernel_configs
)
persistent_mm_kernel_configs = [
{"config": (128, 256, 64, 3, 8), "cond": True},
{"config": (128, 128, 64, 3, 8), "cond": True},
{"config": (128, 128, 128, 3, 8), "cond": True},
{"config": (128, 128, 128, 3, 4), "cond": True},
{"config": (128, 128, 64, 4, 8), "cond": True},
]
scaled_mm_kernel_configs = [
{"config": (128, 256, 32, 3, 8), "cond": True},
{"config": (256, 128, 32, 3, 8), "cond": True},
{"config": (256, 64, 32, 4, 4), "cond": True},
{"config": (64, 256, 32, 4, 4), "cond": True},
{"config": (128, 128, 32, 4, 4), "cond": True},
{"config": (128, 64, 32, 4, 4), "cond": True},
{"config": (64, 128, 32, 4, 4), "cond": True},
{"config": (128, 32, 32, 4, 4), "cond": True},
{"config": (64, 32, 32, 5, 2), "cond": True},
{"config": (256, 128, 128, 3, 8), "cond": True},
{"config": (256, 64, 128, 4, 4), "cond": True},
{"config": (64, 256, 128, 4, 4), "cond": True},
{"config": (128, 128, 128, 4, 4), "cond": True},
{"config": (128, 64, 64, 4, 4), "cond": True},
{"config": (64, 128, 64, 4, 4), "cond": True},
{"config": (128, 32, 64, 4, 4), "cond": True},
{"config": (64, 32, 64, 5, 2), "cond": True},
{"config": (16, 32, 32, 2, 2), "cond": True},
{"config": (16, 64, 32, 2, 2), "cond": True},
{"config": (16, 128, 32, 2, 4), "cond": True},
{"config": (16, 256, 32, 2, 4), "cond": True},
{"config": (16, 32, 64, 2, 2), "cond": True},
{"config": (16, 64, 64, 2, 2), "cond": True},
{"config": (16, 128, 64, 2, 4), "cond": True},
{"config": (16, 256, 64, 2, 4), "cond": True},
{"config": (32, 32, 32, 2, 2), "cond": True},
{"config": (32, 64, 32, 2, 2), "cond": True},
{"config": (32, 128, 32, 2, 4), "cond": True},
{"config": (32, 256, 32, 2, 4), "cond": True},
{"config": (32, 32, 64, 2, 2), "cond": True},
{"config": (32, 64, 64, 2, 2), "cond": True},
{"config": (32, 128, 64, 2, 4), "cond": True},
{"config": (32, 256, 64, 2, 4), "cond": True},
{"config": (16, 32, 32, 3, 2), "cond": True},
{"config": (16, 64, 32, 3, 2), "cond": True},
{"config": (16, 128, 32, 3, 4), "cond": True},
{"config": (16, 256, 32, 3, 4), "cond": True},
{"config": (16, 32, 64, 3, 2), "cond": True},
{"config": (16, 64, 64, 3, 2), "cond": True},
{"config": (16, 128, 64, 3, 4), "cond": True},
{"config": (16, 256, 64, 3, 4), "cond": True},
{"config": (32, 32, 32, 3, 2), "cond": True},
{"config": (32, 64, 32, 3, 2), "cond": True},
{"config": (32, 128, 32, 3, 4), "cond": True},
{"config": (32, 256, 32, 3, 4), "cond": True},
{"config": (32, 32, 64, 3, 2), "cond": True},
{"config": (32, 64, 64, 3, 2), "cond": True},
{"config": (32, 128, 64, 3, 4), "cond": True},
{"config": (32, 256, 64, 3, 4), "cond": True},
{"config": (16, 32, 32, 4, 2), "cond": True},
{"config": (16, 64, 32, 4, 2), "cond": True},
{"config": (16, 128, 32, 4, 4), "cond": True},
{"config": (16, 256, 32, 4, 4), "cond": True},
{"config": (16, 32, 64, 4, 2), "cond": True},
{"config": (16, 64, 64, 4, 2), "cond": True},
{"config": (16, 128, 64, 4, 4), "cond": True},
{"config": (16, 256, 64, 4, 4), "cond": True},
{"config": (32, 32, 32, 4, 2), "cond": True},
{"config": (32, 64, 32, 4, 2), "cond": True},
{"config": (32, 128, 32, 4, 4), "cond": True},
{"config": (32, 256, 32, 4, 4), "cond": True},
{"config": (32, 32, 64, 4, 2), "cond": True},
{"config": (32, 64, 64, 4, 2), "cond": True},
{"config": (32, 128, 64, 4, 4), "cond": True},
{"config": (32, 256, 64, 4, 4), "cond": True},
{"config": (16, 32, 32, 5, 2), "cond": True},
{"config": (16, 64, 32, 5, 2), "cond": True},
{"config": (16, 128, 32, 5, 4), "cond": True},
{"config": (16, 256, 32, 5, 4), "cond": True},
{"config": (16, 32, 64, 5, 2), "cond": True},
{"config": (16, 64, 64, 5, 2), "cond": True},
{"config": (16, 128, 64, 5, 4), "cond": True},
{"config": (16, 256, 64, 5, 4), "cond": True},
{"config": (32, 32, 32, 5, 2), "cond": True},
{"config": (32, 64, 32, 5, 2), "cond": True},
{"config": (32, 128, 32, 5, 4), "cond": True},
{"config": (32, 256, 32, 5, 4), "cond": True},
{"config": (32, 32, 64, 5, 2), "cond": True},
{"config": (32, 64, 64, 5, 2), "cond": True},
{"config": (32, 128, 64, 5, 4), "cond": True},
{"config": (32, 256, 64, 5, 4), "cond": True},
{"config": (16, 32, 32, 6, 2), "cond": True},
{"config": (16, 64, 32, 6, 2), "cond": True},
{"config": (16, 128, 32, 6, 4), "cond": True},
{"config": (16, 256, 32, 6, 4), "cond": True},
{"config": (16, 32, 64, 6, 2), "cond": True},
{"config": (16, 64, 64, 6, 2), "cond": True},
{"config": (16, 128, 64, 6, 4), "cond": True},
{"config": (16, 256, 64, 6, 4), "cond": True},
{"config": (32, 32, 32, 6, 2), "cond": True},
{"config": (32, 64, 32, 6, 2), "cond": True},
{"config": (32, 128, 32, 6, 4), "cond": True},
{"config": (32, 256, 32, 6, 4), "cond": True},
{"config": (32, 32, 64, 6, 2), "cond": True},
{"config": (32, 64, 64, 6, 2), "cond": True},
{"config": (32, 128, 64, 6, 4), "cond": True},
{"config": (32, 256, 64, 6, 4), "cond": True},
]
scaled_persistent_mm_kernel_configs = [
{"config": (128, 128, 64, 3, 8), "cond": True},
{"config": (128, 128, 128, 3, 8), "cond": True},
{"config": (128, 128, 128, 4, 8), "cond": True},
{"config": (128, 128, 128, 4, 4), "cond": True},
{"config": (128, 128, 128, 3, 4), "cond": True},
{"config": (128, 128, 128, 5, 4), "cond": True},
{"config": (128, 128, 128, 5, 8), "cond": True},
{"config": (128, 128, 128, 6, 8), "cond": True},
{"config": (128, 128, 64, 4, 8), "cond": True},
]
# Create filtered list of configs based on cond evaluation
mm_platform_configs = tuple(
cast(tuple[int, int, int, int, int], config["config"])
for config in mm_kernel_configs
if config["cond"]
)
extra_mm_platform_configs = tuple(
cast(tuple[int, int, int, int, int], config["config"])
for config in extra_mm_kernel_configs
if config["cond"]
)
int8_platform_configs = tuple(
cast(tuple[int, int, int, int, int], config["config"])
for config in int8_mm_kernel_configs
if config["cond"]
)
mixed_mm_platform_configs = tuple(
cast(tuple[int, int, int, int, int], config["config"])
for config in mixed_mm_kernel_configs
if config["cond"]
)
persistent_mm_platform_configs = tuple(
cast(tuple[int, int, int, int, int], config["config"])
for config in persistent_mm_kernel_configs
if config["cond"]
)
scaled_mm_platform_configs = tuple(
cast(tuple[int, int, int, int, int], config["config"])
for config in scaled_mm_kernel_configs
if config["cond"]
)
scaled_persistent_mm_platform_configs = tuple(
cast(tuple[int, int, int, int, int], config["config"])
for config in scaled_persistent_mm_kernel_configs
if config["cond"]
)
# On ROCm convert num_stages to improve performance
if torch.version.hip and torch.cuda.is_available():
mm_platform_configs = build_rocm_gemm_configs(mm_platform_configs)
extra_mm_platform_configs = build_rocm_gemm_configs(extra_mm_platform_configs)
int8_platform_configs = build_rocm_gemm_configs(int8_platform_configs)
mixed_mm_platform_configs = build_rocm_gemm_configs(mixed_mm_platform_configs)
scaled_mm_platform_configs = build_rocm_gemm_configs(scaled_mm_platform_configs)
mm_configs = functools.partial(
filtered_configs,
configs=mm_platform_configs,
)
extra_mm_configs = functools.partial(
filtered_configs,
configs=extra_mm_platform_configs,
)
int8_mm_configs = functools.partial(
filtered_configs,
configs=int8_platform_configs,
)
persistent_mm_configs = functools.partial(
filtered_configs,
configs=persistent_mm_platform_configs,
)
scaled_mm_configs = functools.partial(
filtered_configs,
configs=scaled_mm_platform_configs,
)
scaled_persistent_mm_configs = functools.partial(
filtered_configs,
configs=scaled_persistent_mm_platform_configs,
)
def should_fallback_to_aten(choices: list[ChoiceCaller]) -> bool:
if len(choices) == 0 and not use_aten_gemm_kernels():
if inductor_config.autotune_fallback_to_aten:
@ -553,6 +134,15 @@ def mm_args(
return [m, n, k, layout, mat1, mat2, *others]
def mm_config_kwargs(device, exclude_condition):
if device == "cpu":
return {
"scale": 0.5,
"exclude": exclude_condition,
}
return {}
def addmm_epilogue(dtype, alpha, beta):
def epilogue(acc, bias):
if alpha != 1:

View File

@ -1,8 +1,8 @@
# mypy: allow-untyped-defs
import functools
import torch
from .. import ir
from ..lowering import lowerings
from ..select_algorithm import (
autotune_select_algorithm,
@ -112,101 +112,14 @@ mm_plus_mm_template = TritonTemplate(
)
@functools.lru_cache(None)
def mm_configs():
import triton
# List of dictionaries to store the kernel configs. Configs that evaluate to true
# will be utilised on the target platform
mm_triton_configs = [
{
"config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32},
"num_stages": 2,
"num_warps": 4,
"cond": True,
},
{
"config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32},
"num_stages": 3,
"num_warps": 8,
"cond": True,
},
{
"config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32},
"num_stages": 4,
"num_warps": 16,
"cond": True,
},
{
"config": {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32},
"num_stages": 4,
"num_warps": 8,
"cond": True,
},
{
"config": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32},
"num_stages": 4,
"num_warps": 8,
"cond": True,
},
{
"config": {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32},
"num_stages": 1,
"num_warps": 8,
"cond": True,
},
{
"config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64},
"num_stages": 1,
"num_warps": 8,
"cond": True,
},
{
"config": {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 128},
"num_stages": 1,
"num_warps": 8,
"cond": torch.version.hip is None,
},
{
"config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 16},
"num_stages": 2,
"num_warps": 4,
"cond": True,
},
{
"config": {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 16},
"num_stages": 1,
"num_warps": 2,
"cond": True,
},
]
# Filter out configs in which cond evaluates to true
# On ROCm convert num_stages to 1 as pipelining provides no benefit
if torch.version.hip:
filtered_configs = [
triton.Config(c["config"], num_stages=1, num_warps=c["num_warps"])
for c in mm_triton_configs
if c["cond"]
]
else:
filtered_configs = [
triton.Config(
c["config"], num_stages=c["num_stages"], num_warps=c["num_warps"]
)
for c in mm_triton_configs
if c["cond"]
]
return filtered_configs
def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
"""
Computes mm(mat1, mat2) + mm(mat3, mat4)
"""
m1, n1, k1, layout1, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
m2, n2, _, layout2, mat3, mat4 = mm_args(mat3, mat4, layout=layout)
device_type = ir.get_device_type(mat1)
# Optimization is optional, because we can always just not do the fusion
if (
m1 * n1 == 0
@ -231,6 +144,9 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
if use_aten_gemm_kernels()
else []
)
mm_configs = V.choices.get_mm_plus_mm_configs(device_type)
if use_triton_template(layout1):
for config in mm_configs():
# see https://github.com/openai/triton/issues/1298

View File

@ -11,7 +11,7 @@ from torch._inductor.codegen.rocm.ck_universal_gemm_template import CKGemmTempla
from torch.utils._triton import has_triton_tma_device
from ..config import triton as triton_config
from ..ir import _IntLike, ChoiceCaller, Layout, StorageBox, TensorBox
from ..ir import _IntLike, ChoiceCaller, get_device_type, Layout, StorageBox, TensorBox
from ..lowering import add_layout_constraint, constrain_to_fx_strides, register_lowering
from ..select_algorithm import (
autotune_select_algorithm,
@ -27,13 +27,12 @@ from ..utils import (
use_ck_gemm_template,
use_triton_template,
)
from ..virtualized import V
from .mm_common import (
_is_static_problem,
mm_args,
mm_grid,
persistent_mm_grid,
scaled_mm_configs,
scaled_persistent_mm_configs,
should_fallback_to_aten,
)
@ -508,6 +507,7 @@ def tuned_scaled_mm(
m, n, k, layout, mat_a, mat_b = mm_args(
mat_a, mat_b, layout=layout, out_dtype=out_dtype
)
# below is for getting an overview logging info of inductor mms
counters["aten_mm_info"][f"aten._scaled_mm.default_{m}_{n}_{k}"] += 1
log.info(
@ -520,6 +520,8 @@ def tuned_scaled_mm(
layout,
)
device_type = get_device_type(mat_a)
check_supported_striding(mat_a, mat_b)
scale_a, scale_b = realize_inputs(scale_a, scale_b)
@ -544,6 +546,11 @@ def tuned_scaled_mm(
_, is_nonzero = _is_static_problem(layout)
scaled_mm_configs = V.choices.get_scaled_mm_configs(device_type)
scaled_persistent_mm_configs = V.choices.get_scaled_persistent_mm_configs(
device_type
)
if is_nonzero and use_triton_template(layout, enable_float8=True):
if use_persistent_tma(k, bias is not None):
for config in scaled_persistent_mm_configs(m, n, k):

View File

@ -44,7 +44,8 @@ def set_driver_to_gpu():
def get_backend_options():
driver = triton.runtime.driver
from triton.runtime import driver
target = driver.active.get_current_target()
backend = triton.compiler.compiler.make_backend(target)
options = backend.parse_options(dict())

View File

@ -0,0 +1,571 @@
from __future__ import annotations
import itertools
from collections import namedtuple
from functools import partial
from threading import Lock
from typing import Any, Callable, TYPE_CHECKING
from torch.utils._ordered_set import OrderedSet
from . import config
from .utils import get_backend_num_stages
from .virtualized import V
if TYPE_CHECKING:
from collections.abc import Generator, Sequence
from triton import Config as TritonConfig
class BaseConfigSingleton(type):
"""
Thread-safe implementation of single to be used in the config heuristic subclasses
to ensure heavy __init__ calls are not repeatedly run
"""
_instances: dict[type[Any], Any] = {}
_lock: Lock = Lock()
def __call__(
cls: BaseConfigSingleton, *args: Any, **kwargs: Any
) -> BaseConfigHeuristic:
with cls._lock:
if cls not in cls._instances:
instance = super().__call__()
cls._instances[cls] = instance
return cls._instances[cls]
Config = namedtuple(
"Config", ["block_m", "block_n", "block_k", "num_stages", "num_warps"]
)
class BaseConfigHeuristic(metaclass=BaseConfigSingleton):
"""
Base class for mm_configs, device specific triton kernels config inherit from here
"""
def __init__(self) -> None:
# List of dictionaries to store the kernel configs. Configs that evaluate to true
# will be utilised on the target platform. The configs are as follows:
# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
self.mm_configs = [
Config(32, 32, 16, 1, 2),
Config(32, 32, 128, 2, 4),
Config(32, 64, 32, 5, 8),
Config(64, 32, 32, 5, 8),
Config(64, 32, 128, 5, 4),
Config(64, 64, 16, 2, 4),
Config(64, 64, 32, 2, 4),
Config(64, 64, 64, 3, 8),
Config(64, 64, 128, 5, 4),
Config(64, 128, 32, 3, 4),
Config(64, 128, 32, 4, 8),
Config(64, 128, 64, 3, 4),
Config(64, 128, 128, 4, 4),
Config(128, 64, 32, 3, 4),
Config(128, 64, 32, 4, 8),
Config(128, 128, 32, 2, 8),
Config(128, 128, 32, 3, 4),
Config(128, 128, 64, 3, 4),
Config(128, 128, 64, 5, 8),
]
# Exhaustive search for mm configs
self.exhaustive_configs = [
Config(BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product(
[16, 32, 64, 128, 256], repeat=3
)
for num_stages in [1, 2, 3, 4, 5]
for num_warps in [2, 4, 8]
]
# these are only used in tuned_mm when AutoHeuristic is enabled
# the idea is that when AutoHeuristic collects data to learn a heuristic, more configs are autotuned
# when the learned heuristic is used, the learned heuristic reduces the number of configs down to 10
# which saves compilation time (since less configs are autotuned) and potentially increase performance
# because the learned heuristic might predict a config that is not part mm_configs
self.extra_mm_configs = [
Config(16, 32, 16, 3, 2),
Config(16, 32, 32, 4, 2),
Config(16, 32, 32, 5, 2),
Config(64, 64, 128, 3, 4),
Config(128, 64, 32, 2, 2),
Config(128, 64, 64, 3, 8),
Config(128, 64, 128, 4, 8),
Config(128, 128, 32, 4, 4),
Config(128, 128, 64, 3, 8),
Config(128, 128, 64, 5, 4),
]
self.int8_mm_configs = [
Config(64, 64, 32, 2, 4),
Config(64, 128, 32, 3, 4),
Config(128, 64, 32, 3, 4),
Config(64, 128, 32, 4, 8),
Config(128, 64, 32, 4, 8),
Config(64, 32, 32, 5, 8),
Config(32, 64, 32, 5, 8),
Config(128, 128, 32, 2, 8),
Config(64, 64, 64, 3, 8),
Config(128, 256, 128, 3, 8),
Config(256, 128, 128, 3, 8),
]
self.mixed_mm_configs = [
Config(16, 128, 256, 3, 4),
Config(16, 128, 256, 5, 8),
]
self.persistent_mm_configs = [
Config(128, 256, 64, 3, 8),
Config(128, 128, 64, 3, 8),
Config(128, 128, 128, 3, 8),
Config(128, 128, 128, 3, 4),
Config(128, 128, 64, 4, 8),
]
self.scaled_mm_configs = [
Config(128, 256, 32, 3, 8),
Config(256, 128, 32, 3, 8),
Config(256, 64, 32, 4, 4),
Config(64, 256, 32, 4, 4),
Config(128, 128, 32, 4, 4),
Config(128, 64, 32, 4, 4),
Config(64, 128, 32, 4, 4),
Config(128, 32, 32, 4, 4),
Config(64, 32, 32, 5, 2),
Config(256, 128, 128, 3, 8),
Config(256, 64, 128, 4, 4),
Config(64, 256, 128, 4, 4),
Config(128, 128, 128, 4, 4),
Config(128, 64, 64, 4, 4),
Config(64, 128, 64, 4, 4),
Config(128, 32, 64, 4, 4),
Config(64, 32, 64, 5, 2),
Config(16, 32, 32, 2, 2),
Config(16, 64, 32, 2, 2),
Config(16, 128, 32, 2, 4),
Config(16, 256, 32, 2, 4),
Config(16, 32, 64, 2, 2),
Config(16, 64, 64, 2, 2),
Config(16, 128, 64, 2, 4),
Config(16, 256, 64, 2, 4),
Config(32, 32, 32, 2, 2),
Config(32, 64, 32, 2, 2),
Config(32, 128, 32, 2, 4),
Config(32, 256, 32, 2, 4),
Config(32, 32, 64, 2, 2),
Config(32, 64, 64, 2, 2),
Config(32, 128, 64, 2, 4),
Config(32, 256, 64, 2, 4),
Config(16, 32, 32, 3, 2),
Config(16, 64, 32, 3, 2),
Config(16, 128, 32, 3, 4),
Config(16, 256, 32, 3, 4),
Config(16, 32, 64, 3, 2),
Config(16, 64, 64, 3, 2),
Config(16, 128, 64, 3, 4),
Config(16, 256, 64, 3, 4),
Config(32, 32, 32, 3, 2),
Config(32, 64, 32, 3, 2),
Config(32, 128, 32, 3, 4),
Config(32, 256, 32, 3, 4),
Config(32, 32, 64, 3, 2),
Config(32, 64, 64, 3, 2),
Config(32, 128, 64, 3, 4),
Config(32, 256, 64, 3, 4),
Config(16, 32, 32, 4, 2),
Config(16, 64, 32, 4, 2),
Config(16, 128, 32, 4, 4),
Config(16, 256, 32, 4, 4),
Config(16, 32, 64, 4, 2),
Config(16, 64, 64, 4, 2),
Config(16, 128, 64, 4, 4),
Config(16, 256, 64, 4, 4),
Config(32, 32, 32, 4, 2),
Config(32, 64, 32, 4, 2),
Config(32, 128, 32, 4, 4),
Config(32, 256, 32, 4, 4),
Config(32, 32, 64, 4, 2),
Config(32, 64, 64, 4, 2),
Config(32, 128, 64, 4, 4),
Config(32, 256, 64, 4, 4),
Config(16, 32, 32, 5, 2),
Config(16, 64, 32, 5, 2),
Config(16, 128, 32, 5, 4),
Config(16, 256, 32, 5, 4),
Config(16, 32, 64, 5, 2),
Config(16, 64, 64, 5, 2),
Config(16, 128, 64, 5, 4),
Config(16, 256, 64, 5, 4),
Config(32, 32, 32, 5, 2),
Config(32, 64, 32, 5, 2),
Config(32, 128, 32, 5, 4),
Config(32, 256, 32, 5, 4),
Config(32, 32, 64, 5, 2),
Config(32, 64, 64, 5, 2),
Config(32, 128, 64, 5, 4),
Config(32, 256, 64, 5, 4),
Config(16, 32, 32, 6, 2),
Config(16, 64, 32, 6, 2),
Config(16, 128, 32, 6, 4),
Config(16, 256, 32, 6, 4),
Config(16, 32, 64, 6, 2),
Config(16, 64, 64, 6, 2),
Config(16, 128, 64, 6, 4),
Config(16, 256, 64, 6, 4),
Config(32, 32, 32, 6, 2),
Config(32, 64, 32, 6, 2),
Config(32, 128, 32, 6, 4),
Config(32, 256, 32, 6, 4),
Config(32, 32, 64, 6, 2),
Config(32, 64, 64, 6, 2),
Config(32, 128, 64, 6, 4),
Config(32, 256, 64, 6, 4),
]
self.scaled_persistent_mm_configs = [
Config(128, 128, 64, 3, 8),
Config(128, 128, 128, 3, 8),
Config(128, 128, 128, 4, 8),
Config(128, 128, 128, 4, 4),
Config(128, 128, 128, 3, 4),
Config(128, 128, 128, 5, 4),
Config(128, 128, 128, 5, 8),
Config(128, 128, 128, 6, 8),
Config(128, 128, 64, 4, 8),
]
# TODO: Unify with other gemm patterns, mm_plus_mm currently follows
# slightly different pattern than rest
self.mm_plus_mm_configs = [
Config(64, 64, 32, 2, 4),
Config(64, 64, 32, 3, 8),
Config(64, 64, 32, 4, 16),
Config(64, 32, 32, 4, 8),
Config(32, 64, 32, 4, 8),
Config(128, 128, 32, 1, 8),
Config(64, 64, 64, 1, 8),
Config(32, 32, 128, 1, 8),
Config(64, 64, 16, 2, 4),
Config(32, 32, 16, 1, 2),
]
self.conv_configs = [
Config(64, 256, 16, 2, 4),
Config(256, 64, 16, 2, 4),
Config(1024, 16, 16, 1, 8),
Config(128, 128, 32, 2, 8),
Config(64, 64, 32, 2, 4),
Config(64, 256, 32, 2, 8),
Config(256, 64, 32, 2, 8),
]
def _finalize_mm_configs(
self,
configs: list[Config],
) -> Generator[TritonConfig, None, None]:
"""
Finalizes configs after scaling, applying additional constraints.
"""
used = OrderedSet[Config]()
max_mm_configs = config.test_configs.max_mm_configs
for block_m, block_n, block_k, num_stages, num_warps in configs:
# Each warp computes a 16x16 tile = 256 elements
num_warps = min(num_warps, block_m * block_n // 256)
if (
Config(block_m, block_n, block_k, num_stages, num_warps)
) not in used and (max_mm_configs is None or len(used) < max_mm_configs):
used.add(Config(block_m, block_n, block_k, num_stages, num_warps))
yield self.triton_config(
BLOCK_M=block_m,
BLOCK_N=block_n,
BLOCK_K=block_k,
num_stages=num_stages,
num_warps=num_warps,
)
def _scale_mm_configs(
self,
m: int,
n: int,
k: int,
configs: Sequence[Config],
scale: float,
has_int8_tensor: bool,
exclude: Callable[[int, int, int], bool],
) -> list[Config]:
"""
Scales and filters matrix multiplication configs based on input size.
"""
from .runtime.runtime_utils import next_power_of_2
min_block_size = 16
min_block_size_k = 32 if has_int8_tensor else 16
m = max(
next_power_of_2(
V.graph.sizevars.size_hint(
m,
fallback=config.unbacked_symint_fallback, # type: ignore[arg-type]
)
),
min_block_size,
)
n = max(
next_power_of_2(
V.graph.sizevars.size_hint(
n,
fallback=config.unbacked_symint_fallback, # type: ignore[arg-type]
)
),
min_block_size,
)
k = max(
next_power_of_2(
V.graph.sizevars.size_hint(
k,
fallback=config.unbacked_symint_fallback, # type: ignore[arg-type]
)
),
min_block_size_k,
)
scaled_configs = []
for c in configs:
scaled_config = c._replace(
block_m=max(min(int(c.block_m * scale), m), min_block_size),
block_n=max(min(int(c.block_n * scale), n), min_block_size),
block_k=max(min(int(c.block_k * scale), k), min_block_size_k),
)
if not exclude(
scaled_config.block_m, scaled_config.block_n, scaled_config.block_k
):
scaled_configs.append(scaled_config)
return scaled_configs
def preprocess_mm_configs(
self,
m: int,
n: int,
k: int,
configs: Sequence[Config],
has_int8_tensor: bool = False,
scale: int = 1,
exclude: Callable[[int, int, int], bool] = lambda m, n, k: False,
) -> Generator[TritonConfig, None, None]:
scaled_configs = self._scale_mm_configs(
m, n, k, configs, scale, has_int8_tensor, exclude
)
return self._finalize_mm_configs(scaled_configs)
def triton_config(
self, num_stages: int, num_warps: int, **kwargs: Any
) -> TritonConfig:
from triton import Config as TritonConfig # type: ignore[attr-defined]
return TritonConfig(kwargs, num_stages=num_stages, num_warps=num_warps)
def get_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
return partial(self.preprocess_mm_configs, configs=self.mm_configs)
def get_exhaustive_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
return partial(self.preprocess_mm_configs, configs=self.exhaustive_configs)
def get_extra_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
return partial(self.preprocess_mm_configs, configs=self.extra_mm_configs)
def get_int8_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
return partial(self.preprocess_mm_configs, configs=self.int8_mm_configs)
def get_mixed_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
mm_configs = (
self.mm_configs + self.mixed_mm_configs
if config.max_autotune_gemm_search_space == "EXHAUSTIVE"
else self.mm_configs
)
return partial(self.preprocess_mm_configs, configs=mm_configs)
def get_persistent_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
return partial(self.preprocess_mm_configs, configs=self.persistent_mm_configs)
def get_scaled_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
return partial(self.preprocess_mm_configs, configs=self.scaled_mm_configs)
def get_scaled_persistent_mm_configs(
self,
) -> partial[Generator[TritonConfig, None, None]]:
return partial(
self.preprocess_mm_configs, configs=self.scaled_persistent_mm_configs
)
def get_mm_plus_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
return partial(self._finalize_mm_configs, configs=self.mm_plus_mm_configs)
def get_conv_configs(self) -> partial[Generator[TritonConfig, None, None]]:
return partial(self.preprocess_mm_configs, configs=self.conv_configs)
class CPUConfigHeuristic(BaseConfigHeuristic):
pass
class CUDAConfigHeuristic(BaseConfigHeuristic):
pass
class ROCmConfigHeuristic(BaseConfigHeuristic):
def __init__(self) -> None:
super().__init__()
self.default_num_stages = get_backend_num_stages()
# Exhaustive search for mm configs
self.exhaustive_configs = [
Config(BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product(
[16, 32, 64, 128, 256], repeat=3
)
for num_stages in [1, self.default_num_stages]
for num_warps in [4, 8]
]
def _filter_configs(
self, configs: list[Config], new_num_stages: int
) -> list[Config]:
filtered_configs = [
c._replace(num_stages=self.default_num_stages) for c in configs
]
return filtered_configs
def _finalize_mm_configs(
self,
configs: list[Config],
) -> Generator[TritonConfig, None, None]:
used = OrderedSet[tuple[Config, int, int]]()
max_mm_configs = config.test_configs.max_mm_configs
for block_m, block_n, block_k, num_stages, num_warps in configs:
# each warp computes 16x16 tile = 256
num_warps = min(num_warps, block_m * block_n // 256)
kpack = 2
for matrix_instr_nonkdim in [0, 16]:
if matrix_instr_nonkdim != 0 and (
block_m % matrix_instr_nonkdim != 0
or block_n % matrix_instr_nonkdim != 0
):
# block_m and block_n must be a multiple of matrix_instr_nonkdim
continue
if (
Config(
block_m,
block_n,
block_k,
num_stages,
num_warps,
),
matrix_instr_nonkdim,
kpack,
) not in used and (
max_mm_configs is None or len(used) < max_mm_configs
):
used.add(
(
Config(
block_m,
block_n,
block_k,
num_stages,
num_warps,
),
matrix_instr_nonkdim,
kpack,
)
)
yield self.triton_config(
BLOCK_M=block_m,
BLOCK_N=block_n,
BLOCK_K=block_k,
num_stages=num_stages,
num_warps=num_warps,
matrix_instr_nonkdim=matrix_instr_nonkdim,
kpack=kpack,
)
def get_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
filtered_configs = self._filter_configs(
self.mm_configs, self.default_num_stages
)
return partial(self.preprocess_mm_configs, configs=filtered_configs)
def get_exhaustive_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
filtered_configs = self._filter_configs(
self.exhaustive_configs, self.default_num_stages
)
return partial(self.preprocess_mm_configs, configs=filtered_configs)
def get_extra_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
filtered_configs = self._filter_configs(
self.extra_mm_configs, self.default_num_stages
)
return partial(self.preprocess_mm_configs, configs=filtered_configs)
def get_int8_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
filtered_configs = self._filter_configs(
self.int8_mm_configs, self.default_num_stages
)
return partial(self.preprocess_mm_configs, configs=filtered_configs)
def get_mixed_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
mm_configs = (
self.mm_configs + self.mixed_mm_configs
if config.max_autotune_gemm_search_space == "EXHAUSTIVE"
else self.mm_configs
)
filtered_configs = self._filter_configs(mm_configs, self.default_num_stages)
return partial(self.preprocess_mm_configs, configs=filtered_configs)
def get_persistent_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
filtered_configs = self._filter_configs(
self.persistent_mm_configs, self.default_num_stages
)
return partial(self.preprocess_mm_configs, configs=filtered_configs)
def get_scaled_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
filtered_configs = self._filter_configs(
self.scaled_mm_configs, self.default_num_stages
)
return partial(self.preprocess_mm_configs, configs=filtered_configs)
def get_scaled_persistent_mm_configs(
self,
) -> partial[Generator[TritonConfig, None, None]]:
filtered_configs = self._filter_configs(
self.scaled_persistent_mm_configs, self.default_num_stages
)
return partial(self.preprocess_mm_configs, configs=filtered_configs)
def get_mm_plus_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
filtered_configs = self._filter_configs(self.mm_plus_mm_configs, 1)
return partial(self._finalize_mm_configs, configs=filtered_configs)
def get_conv_configs(self) -> partial[Generator[TritonConfig, None, None]]:
filtered_configs = self._filter_configs(
self.conv_configs, self.default_num_stages
)
return partial(self.preprocess_mm_configs, configs=filtered_configs)
class XPUConfigHeuristic(BaseConfigHeuristic):
pass