Revert "[ROCm] ROCm-specific gemm tuning parameters" (#147388)

Summary:
This diff reverts D69573225 / https://github.com/pytorch/pytorch/pull/143286

15% cold compile time regression, see https://fb.workplace.com/groups/1075192433118967/permalink/1608559059782299/

Test Plan: NA

Differential Revision: D69790102

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147388
Approved by: https://github.com/yanboliang
Jason Ansel 2025-02-19 04:47:35 +00:00 committed by PyTorch MergeBot
parent 4ece056791
commit 465930ee81
10 changed files with 629 additions and 802 deletions
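For context, the diff below restores the pre-#143286 layout in which the GEMM config tables live as module-level tuples in mm_common.py and the ROCm-specific handling is limited to rewriting the num_stages slot of each config. A minimal, illustrative sketch of that rewrite (not part of the diff) follows; the hard-coded default of 2 stands in for torch._inductor.utils.get_backend_num_stages(), whose actual value depends on the active Triton backend, and the _sketch names are hypothetical.

# Illustrative sketch of the num_stages rewrite done by build_rocm_gemm_configs
# in the restored mm_common.py below (names suffixed with _sketch are hypothetical).
def build_rocm_gemm_configs_sketch(configs, rocm_num_stages=2):
    # Each entry is (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps); on ROCm the
    # num_stages slot is replaced with the backend default, since deep software
    # pipelining provides little benefit there.
    return tuple((c[0], c[1], c[2], rocm_num_stages, c[4]) for c in configs)

print(build_rocm_gemm_configs_sketch(((64, 64, 32, 5, 8), (128, 128, 32, 3, 4))))
# ((64, 64, 32, 2, 8), (128, 128, 32, 2, 4))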

View File

@@ -1,32 +1,20 @@
from __future__ import annotations
import typing
from typing import Any, Generator, Optional, TYPE_CHECKING
from typing import Any, TYPE_CHECKING
import sympy
import torch
from . import config
from .codecache import write_text
from .metrics import get_metric_table, is_metric_table_enabled
from .runtime.hints import DeviceProperties, ReductionHint
from .scheduler import BaseSchedulerNode, Scheduler, WhyNoFuse
from .template_heuristics import (
BaseConfigHeuristic,
CPUConfigHeuristic,
CUDAConfigHeuristic,
ROCmConfigHeuristic,
XPUConfigHeuristic,
)
from .virtualized import V
if TYPE_CHECKING:
from functools import partial
from triton import Config as TritonConfig
import torch
from torch.utils._ordered_set import OrderedSet
from .codegen.simd_kernel_features import SIMDKernelFeatures
@@ -53,80 +41,6 @@ class InductorChoices:
torch._inductor.virtualized.V.set_choices_handler(MyHeuristics())
"""
def get_config_heuristics(
self, device_type: Optional[str] = "cuda"
) -> BaseConfigHeuristic:
if device_type == "cuda":
if torch.version.hip is None:
return CUDAConfigHeuristic()
else:
return ROCmConfigHeuristic()
elif device_type == "xpu":
return XPUConfigHeuristic()
elif device_type == "cpu":
return CPUConfigHeuristic()
else:
return BaseConfigHeuristic()
# GEMM configs
def get_base_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
if config.max_autotune_gemm_search_space != "EXHAUSTIVE":
return mm_heuristics.get_mm_configs()
else:
return mm_heuristics.get_exhaustive_mm_configs()
def get_extra_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_extra_mm_configs()
def get_int8_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_int8_mm_configs()
def get_mixed_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_mixed_mm_configs()
def get_persistent_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_persistent_mm_configs()
def get_scaled_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_scaled_mm_configs()
def get_scaled_persistent_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_scaled_persistent_mm_configs()
def get_mm_plus_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_mm_plus_mm_configs()
# Conv configs
def get_conv_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
conv_heuristics = self.get_config_heuristics(device_type)
return conv_heuristics.get_conv_configs()
def triton_kernel_kwargs(
self,
kernel_cls: type[TritonKernel],

View File

@@ -23,7 +23,7 @@ from .mm_common import (
_is_static_problem,
addmm_epilogue,
mm_args,
mm_config_kwargs,
mm_configs,
mm_options,
)
@@ -43,6 +43,12 @@ def _is_large_block_for_cpu(m, n, k):
return m * n > 2**12
def bmm_configs(m, n, k, *, device_type):
if device_type == "cpu":
return mm_configs(m, n, k, scale=0.5, exclude=_is_large_block_for_cpu)
return mm_configs(m, n, k)
bmm_template = TritonTemplate(
name="bmm",
grid=bmm_grid,
@@ -163,14 +169,8 @@ def tuned_bmm(mat1, mat2, *, layout=None):
# options to tune from
choices = [aten_bmm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else []
device_type = ir.get_device_type(mat1)
bmm_configs = V.choices.get_base_mm_configs(device_type)
if use_triton_template(layout):
for config in bmm_configs(
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
):
for config in bmm_configs(m, n, k, device_type=ir.get_device_type(mat1)):
bmm_template.maybe_append_choice(
choices,
input_nodes=(mat1, mat2),
@@ -212,14 +212,8 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
if use_aten_gemm_kernels()
else []
)
device_type = ir.get_device_type(mat1)
bmm_configs = V.choices.get_base_mm_configs(device_type)
if use_triton_template(layout):
for config in bmm_configs(
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
):
for config in bmm_configs(m, n, k, device_type=ir.get_device_type(mat1)):
bmm_template.maybe_append_choice(
choices,
input_nodes=(inp, mat1, mat2),

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
import logging
from typing import Optional, TYPE_CHECKING, TypedDict
from typing import cast, Optional, TYPE_CHECKING, TypedDict
import torch
from torch._inductor.codegen.rocm.ck_conv_template import CKGroupedConvFwdTemplate
@@ -29,7 +29,7 @@ from ..utils import (
use_triton_template,
)
from ..virtualized import V
from .mm_common import mm_config_kwargs
from .mm_common import build_rocm_gemm_configs, filtered_configs
if TYPE_CHECKING:
@@ -59,6 +59,31 @@ def conv3d_grid(n, c, d, h, w, meta):
)
# List of dictionaries that store the kernel configs. Configs whose cond evaluates
# to true will be used on the target platform
kernel_configs = [
# "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps"
{"config": (64, 256, 16, 2, 4), "cond": True},
{"config": (256, 64, 16, 2, 4), "cond": True},
{"config": (1024, 16, 16, 1, 8), "cond": True},
{"config": (128, 128, 32, 2, 8), "cond": True},
{"config": (64, 64, 32, 2, 4), "cond": True},
{"config": (64, 256, 32, 2, 8), "cond": True},
{"config": (256, 64, 32, 2, 8), "cond": True},
]
# Create a filtered list of conv configs based on cond evaluation
platform_configs = tuple(
cast(tuple[int, int, int, int, int], config["config"])
for config in kernel_configs
if config["cond"]
)
# On ROCm, rewrite num_stages to the backend default, as deep pipelining provides no benefit
if torch.version.hip and torch.cuda.is_available():
platform_configs = build_rocm_gemm_configs(platform_configs)
def _is_large_block_for_cpu(m, n, k):
# Thresholds are experimentally determined to reduce Triton CPU compile times
if m > 256 or n > 256 or k > 256:
@@ -66,6 +91,19 @@ def _is_large_block_for_cpu(m, n, k):
return m * n * k > 2**17
def conv_configs(m, n, k, *, device_type, **kwargs):
if device_type == "cpu":
return filtered_configs(
m,
n,
k,
configs=platform_configs,
scale=0.5,
exclude=_is_large_block_for_cpu,
)
return filtered_configs(m, n, k, configs=platform_configs)
LOOP_BODY_2D = """
idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H
idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W
@@ -457,8 +495,6 @@ def convolution(
"groups": groups,
}
device_type = ir.get_device_type(x)
if len(x.get_size()) == len(weight.get_size()) - 1:
# add batch dimension to simplify rest of function
return L[aten.squeeze](
@@ -473,7 +509,11 @@ def convolution(
# Always convert conv1D to 2D for Intel GPU.
# Only conv2D can be converted to channel last layout,
# which have much better performance.
if len(x.get_size()) == 3 and len(kernel_shape) == 1 and device_type == "xpu":
if (
len(x.get_size()) == 3
and len(kernel_shape) == 1
and ir.get_device_type(x) == "xpu"
):
kwargs.update(
{
"stride": (1,) + stride,
@@ -522,7 +562,7 @@ def convolution(
):
return convert_1x1_conv_to_mm(x, weight, bias)
if bias is not None and device_type != "cpu":
if bias is not None and ir.get_device_type(x) != "cpu":
# peel off the bias, cudnn is slower with it
result = convolution(x, weight, None, **kwargs)
return L[aten.add](
@@ -597,13 +637,11 @@ def convolution(
):
choices.append(aten_conv1x1_via_mm.bind(args, layout))
conv_configs = V.choices.get_conv_configs(device_type)
for cfg in conv_configs(
sympy_product([x.get_size()[0], *x.get_size()[2:]]),
out_chan,
in_chan,
**mm_config_kwargs(device_type, _is_large_block_for_cpu),
device_type=ir.get_device_type(x),
):
if ndim == 2:
conv2d_template.maybe_append_choice(

View File

@@ -43,12 +43,17 @@ from ..utils import (
from .mm_common import (
_is_static_problem,
addmm_epilogue,
extra_mm_configs,
int8_mm_configs,
mixed_mm_configs,
mm_args,
mm_config_kwargs,
mm_configs,
mm_grid,
mm_options,
persistent_mm_configs,
persistent_mm_grid,
persistent_mm_options,
triton_config,
)
@@ -335,6 +340,15 @@ def _is_large_block_for_cpu(m, n, k):
return m * n > 2**13
def mm_config_kwargs(device):
if device == "cpu":
return {
"scale": 0.5,
"exclude": _is_large_block_for_cpu,
}
return {}
def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1):
"""
Giving torch.addmm a 1D tensor calls a different (faster) cublasLt
@@ -352,7 +366,6 @@ aten_bias_addmm = ExternKernelChoice(bias_addmm, None)
@register_lowering(aten.mm, type_promotion_kind=None)
def tuned_mm(mat1, mat2, *, layout=None):
m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
device_type = ir.get_device_type(mat1)
name = "mm"
aten_layout = layout
@@ -366,15 +379,8 @@ def tuned_mm(mat1, mat2, *, layout=None):
[aten_mm.bind((mat1, mat2), aten_layout)] if use_aten_gemm_kernels() else []
)
static_shape, is_nonzero = _is_static_problem(layout)
mm_configs = V.choices.get_base_mm_configs(device_type)
persistent_mm_configs = V.choices.get_persistent_mm_configs(device_type)
extra_mm_configs = V.choices.get_extra_mm_configs(device_type)
if is_nonzero and use_triton_template(layout):
for config in mm_configs(
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
):
for config in mm_configs(m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))):
mm_template.maybe_append_choice(
choices,
input_nodes=(mat1, mat2),
@@ -383,7 +389,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
)
if use_triton_tma_template(mat1, mat2):
for config in persistent_mm_configs(
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))
):
persistent_tma_mm_template.maybe_append_choice(
choices,
@@ -422,7 +428,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
always_included.append("extern_mm")
num_choices_before_extra_configs = len(choices)
for config in extra_mm_configs(
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))
):
mm_template.maybe_append_choice(
choices,
@@ -482,7 +488,6 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
m, n, k, layout, mat1, mat2 = mm_args(
mat1, mat2, layout=layout, out_dtype=torch.int32
)
device_type = ir.get_device_type(mat1)
static_shape, is_nonzero = _is_static_problem(layout)
use_cutlass = static_shape and is_nonzero and use_cutlass_template(layout, m, n, k)
@@ -498,12 +503,9 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True
)
int8_mm_configs = V.choices.get_int8_mm_configs(device_type)
if is_nonzero and use_triton_template(layout, enable_int32=True):
for config in int8_mm_configs(
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))
):
mm_template.maybe_append_choice(
choices,
@@ -530,7 +532,6 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
@register_lowering(aten.addmm, type_promotion_kind=None)
def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
ordered_kwargs_for_cpp_kernel = ("beta", "alpha")
device_type = ir.get_device_type(mat1)
m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout)
static_shape, is_nonzero = _is_static_problem(layout)
if (not is_nonzero) or (not use_max_autotune()):
@@ -583,13 +584,8 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
),
)
mm_configs = V.choices.get_base_mm_configs(device_type)
persistent_mm_configs = V.choices.get_persistent_mm_configs(device_type)
if is_nonzero and use_triton_template(layout):
for config in mm_configs(
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
):
for config in mm_configs(m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))):
mm_template.maybe_append_choice(
choices,
input_nodes=(inp_expanded, mat1, mat2),
@@ -601,7 +597,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
if use_triton_tma_template(mat1, mat2):
for config in persistent_mm_configs(
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))
):
persistent_tma_mm_template.maybe_append_choice(
choices,
@@ -764,7 +760,7 @@ def dims_are_int(dims):
return all(isinstance(dim, int) for dim in dims)
def try_heuristic(m, n, k, choices, mat1, mat2, mat2_dtype, layout, mm_heuristic):
def try_heuristic(m, n, k, choices, mat1, mat2, mat2_dtype, layout):
m, n, k = get_size_hints(mat1, mat2, m, n, k)
if not dims_are_int([m, n, k]):
return None
@@ -779,10 +775,35 @@ def try_heuristic(m, n, k, choices, mat1, mat2, mat2_dtype, layout, mm_heuristic
not torch.cuda.get_device_capability() >= (8, 0)
) or get_gpu_shared_memory() != 166912:
return None
elif m == 1 and (n % 16 != 0 or k % 16 != 0):
if m == 1 and (n % 16 != 0 or k % 16 != 0):
return None
else:
return mm_heuristic.generate_mixed_mm_config(m, n, k)
if m <= 16 and n >= 4096 and k >= 4096:
return triton_config(
BLOCK_M=16,
BLOCK_N=64,
BLOCK_K=128,
num_stages=5,
num_warps=4,
)
elif m > 16 and m <= 32 and n >= 4096 and k >= 4096:
return triton_config(
BLOCK_M=32,
BLOCK_N=32,
BLOCK_K=128,
num_stages=5,
num_warps=4,
)
elif m > 32 and m <= 64 and n >= 4096 and k >= 4096:
return triton_config(
BLOCK_M=64,
BLOCK_N=32,
BLOCK_K=128,
num_stages=5,
num_warps=4,
)
return None
def mm_autoheuristic(
@@ -879,7 +900,6 @@ def get_size_hints_strides(mat1, mat2):
def tuned_mixed_mm(mat1, mat2, mat2_dtype):
m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=None)
device_type = ir.get_device_type(mat1)
static_shape, is_nonzero = _is_static_problem(layout)
fallback = aten_fallback_mixed_mm.bind((mat1, mat2), layout)
@@ -905,15 +925,10 @@ def tuned_mixed_mm(mat1, mat2, mat2_dtype):
choices = []
if not skip_triton:
mm_heuristic = V.choices.get_config_heuristics(device_type)
mixed_mm_configs = V.choices.get_mixed_mm_configs(device_type)
b_prologue_cast_type = f"tl.{mat2_dtype}".replace("torch.", "")
if static_shape and inductor_config.mixed_mm_choice == "heuristic":
choices = []
config = try_heuristic(
m, n, k, choices, mat1, mat2, mat2_dtype, layout, mm_heuristic
)
config = try_heuristic(m, n, k, choices, mat1, mat2, mat2_dtype, layout)
if config is not None:
mm_template.maybe_append_choice(
choices,
@@ -924,13 +939,12 @@ def tuned_mixed_mm(mat1, mat2, mat2_dtype):
choices.append(fallback)
has_int8_tensor = _is_int8_mat(mat1) or _is_int8_mat(mat2)
for config in mixed_mm_configs(
m,
n,
k,
has_int8_tensor=has_int8_tensor,
**mm_config_kwargs(device_type, _is_large_block_for_cpu),
**mm_config_kwargs(ir.get_device_type(mat1)),
):
mm_template.maybe_append_choice(
choices,
@@ -987,15 +1001,13 @@ def tuned_fused_int_mm_mul(mat1, mat2, mat3, out_dtype, *, layout=None):
m, n, k, layout, mat1, mat2, mat3 = mm_args(
mat1, mat2, mat3, layout=layout, out_dtype=out_dtype
)
device_type = ir.get_device_type(mat1)
def mul_epilogue(v1, v2):
return V.ops.mul(v1, v2)
choices: list[dict[Any, Any]] = []
int8_mm_configs = V.choices.get_int8_mm_configs(device_type)
for config in int8_mm_configs(
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))
):
mm_template.maybe_append_choice(
choices,

View File

@@ -1,22 +1,437 @@
# mypy: allow-untyped-defs
import functools
import itertools
import logging
from typing import Any
from collections.abc import Sequence
from typing import Any, cast
import sympy
import torch
from torch._inductor.select_algorithm import realize_inputs
from torch._inductor.virtualized import V
from torch.utils._ordered_set import OrderedSet
from .. import config as inductor_config
from ..codegen.wrapper import PythonWrapperCodegen
from ..ir import Layout
from ..utils import ceildiv as cdiv, get_num_sms, TMA_DESCRIPTOR_SIZE
from ..runtime.runtime_utils import next_power_of_2
from ..utils import (
ceildiv as cdiv,
get_backend_num_stages,
get_num_sms,
TMA_DESCRIPTOR_SIZE,
)
log = logging.getLogger(__name__)
def triton_config(num_stages, num_warps, **kwargs):
from triton import Config # type: ignore[attr-defined]
return Config(kwargs, num_stages=num_stages, num_warps=num_warps)
def build_rocm_gemm_configs(configs):
rocm_num_stages = get_backend_num_stages()
return tuple((c[0], c[1], c[2], rocm_num_stages, c[4]) for c in configs)
def filtered_configs(
m: int,
n: int,
k: int,
configs: Sequence[tuple[int, int, int, int, int]],
has_int8_tensor=False,
scale=1,
exclude=lambda m, n, k: False,
):
"""
Heuristic to shrink configs when they are bigger than the input size
:param scale: scale factor applied to the config values
:param exclude: whether a given config should be excluded
"""
from torch._inductor import config
max_mm_configs = config.test_configs.max_mm_configs
min_block_size = 16
# block_k=16 seems to be causing issues
# see: https://github.com/triton-lang/triton/issues/2156#issuecomment-1695897424
min_block_size_k = 32 if has_int8_tensor else 16
m = max(
next_power_of_2(
V.graph.sizevars.size_hint(
m, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type]
)
),
min_block_size,
)
n = max(
next_power_of_2(
V.graph.sizevars.size_hint(
n, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type]
)
),
min_block_size,
)
k = max(
next_power_of_2(
V.graph.sizevars.size_hint(
k, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type]
)
),
min_block_size_k,
)
used = OrderedSet[tuple[int, int, int, int, int, int]]()
for block_m, block_n, block_k, num_stages, num_warps in configs:
# shrink configs for small sizes
block_m = max(min(int(block_m * scale), m), min_block_size)
block_n = max(min(int(block_n * scale), n), min_block_size)
block_k = max(min(int(block_k * scale), k), min_block_size_k)
if exclude(block_m, block_n, block_k):
continue
# each warp computes 16x16 tile = 256
num_warps = min(num_warps, block_m * block_n // 256)
if torch.version.hip:
for matrix_instr_nonkdim in [0, 16]:
if matrix_instr_nonkdim != 0 and (
block_m % matrix_instr_nonkdim != 0
or block_n % matrix_instr_nonkdim != 0
):
# block_m and block_n must be a multiple of matrix_instr_nonkdim
continue
if (
block_m,
block_n,
block_k,
num_stages,
num_warps,
matrix_instr_nonkdim,
) not in used and (
max_mm_configs is None or len(used) < max_mm_configs
):
used.add(
(
block_m,
block_n,
block_k,
num_stages,
num_warps,
matrix_instr_nonkdim,
)
)
yield triton_config(
BLOCK_M=block_m,
BLOCK_N=block_n,
BLOCK_K=block_k,
num_stages=num_stages,
num_warps=num_warps,
matrix_instr_nonkdim=matrix_instr_nonkdim,
)
else:
if (block_m, block_n, block_k, num_stages, num_warps, 0) not in used and (
max_mm_configs is None or len(used) < max_mm_configs
):
used.add((block_m, block_n, block_k, num_stages, num_warps, 0))
yield triton_config(
BLOCK_M=block_m,
BLOCK_N=block_n,
BLOCK_K=block_k,
num_stages=num_stages,
num_warps=num_warps,
)
# List of dictionaries that store the kernel configs. Configs whose cond evaluates
# to true will be used on the target platform. Each config is laid out as:
# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
mm_kernel_configs = (
[
{"config": (32, 32, 16, 1, 2), "cond": True},
{"config": (32, 32, 128, 2, 4), "cond": True},
{"config": (32, 64, 32, 5, 8), "cond": True},
{"config": (64, 32, 32, 5, 8), "cond": True},
{"config": (64, 32, 128, 5, 4), "cond": True},
{"config": (64, 64, 16, 2, 4), "cond": True},
{"config": (64, 64, 32, 2, 4), "cond": True},
{"config": (64, 64, 64, 3, 8), "cond": True},
{"config": (64, 64, 128, 5, 4), "cond": True},
{"config": (64, 128, 32, 3, 4), "cond": True},
{"config": (64, 128, 32, 4, 8), "cond": True},
{"config": (64, 128, 64, 3, 4), "cond": True},
{"config": (64, 128, 128, 4, 4), "cond": True},
{"config": (128, 64, 32, 3, 4), "cond": True},
{"config": (128, 64, 32, 4, 8), "cond": True},
{"config": (128, 128, 32, 2, 8), "cond": True},
{"config": (128, 128, 32, 3, 4), "cond": True},
{"config": (128, 128, 64, 3, 4), "cond": True},
{"config": (128, 128, 64, 5, 8), "cond": True},
]
if inductor_config.max_autotune_gemm_search_space != "EXHAUSTIVE"
else [
{"config": (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps), "cond": True}
for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product(
[16, 32, 64, 128, 256], repeat=3
)
for num_stages in [1, 2, 3, 4, 5]
for num_warps in [2, 4, 8]
]
)
# These are only used in tuned_mm when AutoHeuristic is enabled.
# The idea is that while AutoHeuristic collects data to learn a heuristic, more configs are autotuned;
# once the learned heuristic is used, it reduces the number of configs down to 10,
# which saves compilation time (since fewer configs are autotuned) and potentially increases performance,
# because the learned heuristic might predict a config that is not part of mm_configs
extra_mm_kernel_configs = [
{"config": (16, 32, 16, 3, 2), "cond": True},
{"config": (16, 32, 32, 4, 2), "cond": True},
{"config": (16, 32, 32, 5, 2), "cond": True},
{"config": (64, 64, 128, 3, 4), "cond": True},
{"config": (128, 64, 32, 2, 2), "cond": True},
{"config": (128, 64, 64, 3, 8), "cond": True},
{"config": (128, 64, 128, 4, 8), "cond": True},
{"config": (128, 128, 32, 4, 4), "cond": True},
{"config": (128, 128, 64, 3, 8), "cond": True},
{"config": (128, 128, 64, 5, 4), "cond": True},
]
int8_mm_kernel_configs = [
{"config": (64, 64, 32, 2, 4), "cond": True},
{"config": (64, 128, 32, 3, 4), "cond": True},
{"config": (128, 64, 32, 3, 4), "cond": True},
{"config": (64, 128, 32, 4, 8), "cond": True},
{"config": (128, 64, 32, 4, 8), "cond": True},
{"config": (64, 32, 32, 5, 8), "cond": True},
{"config": (32, 64, 32, 5, 8), "cond": True},
{"config": (128, 128, 32, 2, 8), "cond": True},
{"config": (64, 64, 64, 3, 8), "cond": True},
# {"config": (32, 32, 128, 2, 4), "cond": True},
# {"config": (64, 64, 16, 2, 4), "cond": True},
# {"config": (32, 32, 16, 1, 2), "cond": True},
{"config": (128, 256, 128, 3, 8), "cond": True},
{"config": (256, 128, 128, 3, 8), "cond": True},
]
# Mixed precision kernel configs for small sizes of m for mm's like (16, 8192) x (8192, 8192).
mixed_mm_kernel_configs_small_m = [
{"config": (16, 128, 256, 3, 4), "cond": True},
{"config": (16, 128, 256, 5, 8), "cond": True},
]
mixed_mm_kernel_configs = (
mm_kernel_configs + mixed_mm_kernel_configs_small_m
if inductor_config.max_autotune_gemm_search_space != "EXHAUSTIVE"
else mm_kernel_configs
)
persistent_mm_kernel_configs = [
{"config": (128, 256, 64, 3, 8), "cond": True},
{"config": (128, 128, 64, 3, 8), "cond": True},
{"config": (128, 128, 128, 3, 8), "cond": True},
{"config": (128, 128, 128, 3, 4), "cond": True},
{"config": (128, 128, 64, 4, 8), "cond": True},
]
scaled_mm_kernel_configs = [
{"config": (128, 256, 32, 3, 8), "cond": True},
{"config": (256, 128, 32, 3, 8), "cond": True},
{"config": (256, 64, 32, 4, 4), "cond": True},
{"config": (64, 256, 32, 4, 4), "cond": True},
{"config": (128, 128, 32, 4, 4), "cond": True},
{"config": (128, 64, 32, 4, 4), "cond": True},
{"config": (64, 128, 32, 4, 4), "cond": True},
{"config": (128, 32, 32, 4, 4), "cond": True},
{"config": (64, 32, 32, 5, 2), "cond": True},
{"config": (256, 128, 128, 3, 8), "cond": True},
{"config": (256, 64, 128, 4, 4), "cond": True},
{"config": (64, 256, 128, 4, 4), "cond": True},
{"config": (128, 128, 128, 4, 4), "cond": True},
{"config": (128, 64, 64, 4, 4), "cond": True},
{"config": (64, 128, 64, 4, 4), "cond": True},
{"config": (128, 32, 64, 4, 4), "cond": True},
{"config": (64, 32, 64, 5, 2), "cond": True},
{"config": (16, 32, 32, 2, 2), "cond": True},
{"config": (16, 64, 32, 2, 2), "cond": True},
{"config": (16, 128, 32, 2, 4), "cond": True},
{"config": (16, 256, 32, 2, 4), "cond": True},
{"config": (16, 32, 64, 2, 2), "cond": True},
{"config": (16, 64, 64, 2, 2), "cond": True},
{"config": (16, 128, 64, 2, 4), "cond": True},
{"config": (16, 256, 64, 2, 4), "cond": True},
{"config": (32, 32, 32, 2, 2), "cond": True},
{"config": (32, 64, 32, 2, 2), "cond": True},
{"config": (32, 128, 32, 2, 4), "cond": True},
{"config": (32, 256, 32, 2, 4), "cond": True},
{"config": (32, 32, 64, 2, 2), "cond": True},
{"config": (32, 64, 64, 2, 2), "cond": True},
{"config": (32, 128, 64, 2, 4), "cond": True},
{"config": (32, 256, 64, 2, 4), "cond": True},
{"config": (16, 32, 32, 3, 2), "cond": True},
{"config": (16, 64, 32, 3, 2), "cond": True},
{"config": (16, 128, 32, 3, 4), "cond": True},
{"config": (16, 256, 32, 3, 4), "cond": True},
{"config": (16, 32, 64, 3, 2), "cond": True},
{"config": (16, 64, 64, 3, 2), "cond": True},
{"config": (16, 128, 64, 3, 4), "cond": True},
{"config": (16, 256, 64, 3, 4), "cond": True},
{"config": (32, 32, 32, 3, 2), "cond": True},
{"config": (32, 64, 32, 3, 2), "cond": True},
{"config": (32, 128, 32, 3, 4), "cond": True},
{"config": (32, 256, 32, 3, 4), "cond": True},
{"config": (32, 32, 64, 3, 2), "cond": True},
{"config": (32, 64, 64, 3, 2), "cond": True},
{"config": (32, 128, 64, 3, 4), "cond": True},
{"config": (32, 256, 64, 3, 4), "cond": True},
{"config": (16, 32, 32, 4, 2), "cond": True},
{"config": (16, 64, 32, 4, 2), "cond": True},
{"config": (16, 128, 32, 4, 4), "cond": True},
{"config": (16, 256, 32, 4, 4), "cond": True},
{"config": (16, 32, 64, 4, 2), "cond": True},
{"config": (16, 64, 64, 4, 2), "cond": True},
{"config": (16, 128, 64, 4, 4), "cond": True},
{"config": (16, 256, 64, 4, 4), "cond": True},
{"config": (32, 32, 32, 4, 2), "cond": True},
{"config": (32, 64, 32, 4, 2), "cond": True},
{"config": (32, 128, 32, 4, 4), "cond": True},
{"config": (32, 256, 32, 4, 4), "cond": True},
{"config": (32, 32, 64, 4, 2), "cond": True},
{"config": (32, 64, 64, 4, 2), "cond": True},
{"config": (32, 128, 64, 4, 4), "cond": True},
{"config": (32, 256, 64, 4, 4), "cond": True},
{"config": (16, 32, 32, 5, 2), "cond": True},
{"config": (16, 64, 32, 5, 2), "cond": True},
{"config": (16, 128, 32, 5, 4), "cond": True},
{"config": (16, 256, 32, 5, 4), "cond": True},
{"config": (16, 32, 64, 5, 2), "cond": True},
{"config": (16, 64, 64, 5, 2), "cond": True},
{"config": (16, 128, 64, 5, 4), "cond": True},
{"config": (16, 256, 64, 5, 4), "cond": True},
{"config": (32, 32, 32, 5, 2), "cond": True},
{"config": (32, 64, 32, 5, 2), "cond": True},
{"config": (32, 128, 32, 5, 4), "cond": True},
{"config": (32, 256, 32, 5, 4), "cond": True},
{"config": (32, 32, 64, 5, 2), "cond": True},
{"config": (32, 64, 64, 5, 2), "cond": True},
{"config": (32, 128, 64, 5, 4), "cond": True},
{"config": (32, 256, 64, 5, 4), "cond": True},
{"config": (16, 32, 32, 6, 2), "cond": True},
{"config": (16, 64, 32, 6, 2), "cond": True},
{"config": (16, 128, 32, 6, 4), "cond": True},
{"config": (16, 256, 32, 6, 4), "cond": True},
{"config": (16, 32, 64, 6, 2), "cond": True},
{"config": (16, 64, 64, 6, 2), "cond": True},
{"config": (16, 128, 64, 6, 4), "cond": True},
{"config": (16, 256, 64, 6, 4), "cond": True},
{"config": (32, 32, 32, 6, 2), "cond": True},
{"config": (32, 64, 32, 6, 2), "cond": True},
{"config": (32, 128, 32, 6, 4), "cond": True},
{"config": (32, 256, 32, 6, 4), "cond": True},
{"config": (32, 32, 64, 6, 2), "cond": True},
{"config": (32, 64, 64, 6, 2), "cond": True},
{"config": (32, 128, 64, 6, 4), "cond": True},
{"config": (32, 256, 64, 6, 4), "cond": True},
]
scaled_persistent_mm_kernel_configs = [
{"config": (128, 128, 64, 3, 8), "cond": True},
{"config": (128, 128, 128, 3, 8), "cond": True},
{"config": (128, 128, 128, 4, 8), "cond": True},
{"config": (128, 128, 128, 4, 4), "cond": True},
{"config": (128, 128, 128, 3, 4), "cond": True},
{"config": (128, 128, 128, 5, 4), "cond": True},
{"config": (128, 128, 128, 5, 8), "cond": True},
{"config": (128, 128, 128, 6, 8), "cond": True},
{"config": (128, 128, 64, 4, 8), "cond": True},
]
# Create filtered list of configs based on cond evaluation
mm_platform_configs = tuple(
cast(tuple[int, int, int, int, int], config["config"])
for config in mm_kernel_configs
if config["cond"]
)
extra_mm_platform_configs = tuple(
cast(tuple[int, int, int, int, int], config["config"])
for config in extra_mm_kernel_configs
if config["cond"]
)
int8_platform_configs = tuple(
cast(tuple[int, int, int, int, int], config["config"])
for config in int8_mm_kernel_configs
if config["cond"]
)
mixed_mm_platform_configs = tuple(
cast(tuple[int, int, int, int, int], config["config"])
for config in mixed_mm_kernel_configs
if config["cond"]
)
persistent_mm_platform_configs = tuple(
cast(tuple[int, int, int, int, int], config["config"])
for config in persistent_mm_kernel_configs
if config["cond"]
)
scaled_mm_platform_configs = tuple(
cast(tuple[int, int, int, int, int], config["config"])
for config in scaled_mm_kernel_configs
if config["cond"]
)
scaled_persistent_mm_platform_configs = tuple(
cast(tuple[int, int, int, int, int], config["config"])
for config in scaled_persistent_mm_kernel_configs
if config["cond"]
)
# On ROCm, rewrite num_stages to the backend default to improve performance
if torch.version.hip and torch.cuda.is_available():
mm_platform_configs = build_rocm_gemm_configs(mm_platform_configs)
extra_mm_platform_configs = build_rocm_gemm_configs(extra_mm_platform_configs)
int8_platform_configs = build_rocm_gemm_configs(int8_platform_configs)
mixed_mm_platform_configs = build_rocm_gemm_configs(mixed_mm_platform_configs)
scaled_mm_platform_configs = build_rocm_gemm_configs(scaled_mm_platform_configs)
mm_configs = functools.partial(
filtered_configs,
configs=mm_platform_configs,
)
extra_mm_configs = functools.partial(
filtered_configs,
configs=extra_mm_platform_configs,
)
int8_mm_configs = functools.partial(
filtered_configs,
configs=int8_platform_configs,
)
mixed_mm_configs = functools.partial(
filtered_configs,
configs=mixed_mm_platform_configs,
)
persistent_mm_configs = functools.partial(
filtered_configs,
configs=persistent_mm_platform_configs,
)
scaled_mm_configs = functools.partial(
filtered_configs,
configs=scaled_mm_platform_configs,
)
scaled_persistent_mm_configs = functools.partial(
filtered_configs,
configs=scaled_persistent_mm_platform_configs,
)
def mm_grid(m, n, meta):
"""
The CUDA grid size for matmul triton templates.
@@ -115,15 +530,6 @@ def mm_args(
return [m, n, k, layout, mat1, mat2, *others]
def mm_config_kwargs(device, exclude_condition):
if device == "cpu":
return {
"scale": 0.5,
"exclude": exclude_condition,
}
return {}
def addmm_epilogue(dtype, alpha, beta):
def epilogue(acc, bias):
if alpha != 1:

View File

@@ -1,8 +1,8 @@
# mypy: allow-untyped-defs
import functools
import torch
from .. import ir
from ..lowering import lowerings
from ..select_algorithm import (
autotune_select_algorithm,
@@ -112,14 +112,101 @@ mm_plus_mm_template = TritonTemplate(
)
@functools.lru_cache(None)
def mm_configs():
import triton
# List of dictionaries that store the kernel configs. Configs whose cond evaluates
# to true will be used on the target platform
mm_triton_configs = [
{
"config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32},
"num_stages": 2,
"num_warps": 4,
"cond": True,
},
{
"config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32},
"num_stages": 3,
"num_warps": 8,
"cond": True,
},
{
"config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32},
"num_stages": 4,
"num_warps": 16,
"cond": True,
},
{
"config": {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32},
"num_stages": 4,
"num_warps": 8,
"cond": True,
},
{
"config": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32},
"num_stages": 4,
"num_warps": 8,
"cond": True,
},
{
"config": {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32},
"num_stages": 1,
"num_warps": 8,
"cond": True,
},
{
"config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64},
"num_stages": 1,
"num_warps": 8,
"cond": True,
},
{
"config": {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 128},
"num_stages": 1,
"num_warps": 8,
"cond": torch.version.hip is None,
},
{
"config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 16},
"num_stages": 2,
"num_warps": 4,
"cond": True,
},
{
"config": {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 16},
"num_stages": 1,
"num_warps": 2,
"cond": True,
},
]
# Keep only the configs whose cond evaluates to true
# On ROCm convert num_stages to 1 as pipelining provides no benefit
if torch.version.hip:
filtered_configs = [
triton.Config(c["config"], num_stages=1, num_warps=c["num_warps"])
for c in mm_triton_configs
if c["cond"]
]
else:
filtered_configs = [
triton.Config(
c["config"], num_stages=c["num_stages"], num_warps=c["num_warps"]
)
for c in mm_triton_configs
if c["cond"]
]
return filtered_configs
def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
"""
Computes mm(mat1, mat2) + mm(mat3, mat4)
"""
m1, n1, k1, layout1, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
m2, n2, _, layout2, mat3, mat4 = mm_args(mat3, mat4, layout=layout)
device_type = ir.get_device_type(mat1)
# Optimization is optional, because we can always just not do the fusion
if (
m1 * n1 == 0
@@ -144,9 +231,6 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
if use_aten_gemm_kernels()
else []
)
mm_configs = V.choices.get_mm_plus_mm_configs(device_type)
if use_triton_template(layout1):
for config in mm_configs():
# see https://github.com/openai/triton/issues/1298

View File

@@ -10,7 +10,7 @@ from torch.utils._triton import has_triton_tma_device
from .. import config as inductor_config
from ..config import triton as triton_config
from ..ir import _IntLike, ChoiceCaller, get_device_type, Layout, StorageBox, TensorBox
from ..ir import _IntLike, ChoiceCaller, Layout, StorageBox, TensorBox
from ..lowering import add_layout_constraint, constrain_to_fx_strides, register_lowering
from ..select_algorithm import (
autotune_select_algorithm,
@@ -27,8 +27,14 @@ from ..utils import (
use_ck_gemm_template,
use_triton_template,
)
from ..virtualized import V
from .mm_common import _is_static_problem, mm_args, mm_grid, persistent_mm_grid
from .mm_common import (
_is_static_problem,
mm_args,
mm_grid,
persistent_mm_grid,
scaled_mm_configs,
scaled_persistent_mm_configs,
)
log = logging.getLogger(__name__)
@@ -509,7 +515,6 @@ def tuned_scaled_mm(
m, n, k, layout, mat_a, mat_b = mm_args(
mat_a, mat_b, layout=layout, out_dtype=out_dtype
)
device_type = get_device_type(mat_a)
check_supported_striding(mat_a, mat_b)
@@ -535,11 +540,6 @@ def tuned_scaled_mm(
_, is_nonzero = _is_static_problem(layout)
scaled_mm_configs = V.choices.get_scaled_mm_configs(device_type)
scaled_persistent_mm_configs = V.choices.get_scaled_persistent_mm_configs(
device_type
)
if is_nonzero and use_triton_template(layout, enable_float8=True):
if use_persistent_tma(k, bias is not None):
for config in scaled_persistent_mm_configs(m, n, k):

View File

@@ -2,10 +2,8 @@
import logging
from typing import TYPE_CHECKING
from .. import ir
from ..select_algorithm import autotune_select_algorithm, TritonTemplate
from ..virtualized import V
from .mm_common import mm_args, mm_grid, mm_options
from .mm_common import mm_args, mm_configs, mm_grid, mm_options
if TYPE_CHECKING:
@@ -77,13 +75,8 @@ uint4x2_mixed_mm_template = TritonTemplate(
def tuned_uint4x2_mixed_mm(mat1, mat2, mat2_mm_shape, mat2_dtype):
m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=None, use_4x2_dim=True)
device_type = ir.get_device_type(mat1)
choices: list[ChoiceCaller] = []
b_prologue_cast_type = f"tl.{mat2_dtype}".replace("torch.", "")
mm_configs = V.choices.get_base_mm_configs(device_type)
for config in mm_configs(m, n, k):
uint4x2_mixed_mm_template.maybe_append_choice(
choices,

View File

@@ -36,8 +36,7 @@ def set_driver_to_gpu():
def get_backend_options():
from triton.runtime import driver
driver = triton.runtime.driver
target = driver.active.get_current_target()
backend = triton.compiler.compiler.make_backend(target)
options = backend.parse_options(dict())

View File

@@ -1,613 +0,0 @@
from __future__ import annotations
import itertools
from collections import namedtuple
from functools import partial
from threading import Lock
from typing import Any, Callable, Generator, List, Sequence, Tuple, Type, TYPE_CHECKING
from torch.utils._ordered_set import OrderedSet
from . import config
from .utils import get_backend_num_stages
from .virtualized import V
if TYPE_CHECKING:
from triton import Config as TritonConfig
class BaseConfigSingleton(type):
"""
Thread-safe singleton metaclass used by the config heuristic subclasses
to ensure heavy __init__ calls are not repeatedly run
"""
_instances: dict[Type[Any], Any] = {}
_lock: Lock = Lock()
def __call__(
cls: BaseConfigSingleton, *args: Any, **kwargs: Any
) -> BaseConfigHeuristic:
with cls._lock:
if cls not in cls._instances:
instance = super().__call__()
cls._instances[cls] = instance
return cls._instances[cls]
Config = namedtuple(
"Config", ["block_m", "block_n", "block_k", "num_stages", "num_warps"]
)
class BaseConfigHeuristic(metaclass=BaseConfigSingleton):
"""
Base class for mm configs; device-specific Triton kernel config heuristics inherit from here
"""
def __init__(self) -> None:
# Base set of kernel configs used on the target platform. Each config is laid out as:
# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
self.mm_configs = [
Config(32, 32, 16, 1, 2),
Config(32, 32, 128, 2, 4),
Config(32, 64, 32, 5, 8),
Config(64, 32, 32, 5, 8),
Config(64, 32, 128, 5, 4),
Config(64, 64, 16, 2, 4),
Config(64, 64, 32, 2, 4),
Config(64, 64, 64, 3, 8),
Config(64, 64, 128, 5, 4),
Config(64, 128, 32, 3, 4),
Config(64, 128, 32, 4, 8),
Config(64, 128, 64, 3, 4),
Config(64, 128, 128, 4, 4),
Config(128, 64, 32, 3, 4),
Config(128, 64, 32, 4, 8),
Config(128, 128, 32, 2, 8),
Config(128, 128, 32, 3, 4),
Config(128, 128, 64, 3, 4),
Config(128, 128, 64, 5, 8),
]
# Exhaustive search for mm configs
self.exhaustive_configs = [
Config(BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product(
[16, 32, 64, 128, 256], repeat=3
)
for num_stages in [1, 2, 3, 4, 5]
for num_warps in [2, 4, 8]
]
# These are only used in tuned_mm when AutoHeuristic is enabled.
# The idea is that while AutoHeuristic collects data to learn a heuristic, more configs are autotuned;
# once the learned heuristic is used, it reduces the number of configs down to 10,
# which saves compilation time (since fewer configs are autotuned) and potentially increases performance,
# because the learned heuristic might predict a config that is not part of mm_configs
self.extra_mm_configs = [
Config(16, 32, 16, 3, 2),
Config(16, 32, 32, 4, 2),
Config(16, 32, 32, 5, 2),
Config(64, 64, 128, 3, 4),
Config(128, 64, 32, 2, 2),
Config(128, 64, 64, 3, 8),
Config(128, 64, 128, 4, 8),
Config(128, 128, 32, 4, 4),
Config(128, 128, 64, 3, 8),
Config(128, 128, 64, 5, 4),
]
self.int8_mm_configs = [
Config(64, 64, 32, 2, 4),
Config(64, 128, 32, 3, 4),
Config(128, 64, 32, 3, 4),
Config(64, 128, 32, 4, 8),
Config(128, 64, 32, 4, 8),
Config(64, 32, 32, 5, 8),
Config(32, 64, 32, 5, 8),
Config(128, 128, 32, 2, 8),
Config(64, 64, 64, 3, 8),
Config(128, 256, 128, 3, 8),
Config(256, 128, 128, 3, 8),
]
self.mixed_mm_configs = [
Config(16, 128, 256, 3, 4),
Config(16, 128, 256, 5, 8),
]
self.persistent_mm_configs = [
Config(128, 256, 64, 3, 8),
Config(128, 128, 64, 3, 8),
Config(128, 128, 128, 3, 8),
Config(128, 128, 128, 3, 4),
Config(128, 128, 64, 4, 8),
]
self.scaled_mm_configs = [
Config(128, 256, 32, 3, 8),
Config(256, 128, 32, 3, 8),
Config(256, 64, 32, 4, 4),
Config(64, 256, 32, 4, 4),
Config(128, 128, 32, 4, 4),
Config(128, 64, 32, 4, 4),
Config(64, 128, 32, 4, 4),
Config(128, 32, 32, 4, 4),
Config(64, 32, 32, 5, 2),
Config(256, 128, 128, 3, 8),
Config(256, 64, 128, 4, 4),
Config(64, 256, 128, 4, 4),
Config(128, 128, 128, 4, 4),
Config(128, 64, 64, 4, 4),
Config(64, 128, 64, 4, 4),
Config(128, 32, 64, 4, 4),
Config(64, 32, 64, 5, 2),
Config(16, 32, 32, 2, 2),
Config(16, 64, 32, 2, 2),
Config(16, 128, 32, 2, 4),
Config(16, 256, 32, 2, 4),
Config(16, 32, 64, 2, 2),
Config(16, 64, 64, 2, 2),
Config(16, 128, 64, 2, 4),
Config(16, 256, 64, 2, 4),
Config(32, 32, 32, 2, 2),
Config(32, 64, 32, 2, 2),
Config(32, 128, 32, 2, 4),
Config(32, 256, 32, 2, 4),
Config(32, 32, 64, 2, 2),
Config(32, 64, 64, 2, 2),
Config(32, 128, 64, 2, 4),
Config(32, 256, 64, 2, 4),
Config(16, 32, 32, 3, 2),
Config(16, 64, 32, 3, 2),
Config(16, 128, 32, 3, 4),
Config(16, 256, 32, 3, 4),
Config(16, 32, 64, 3, 2),
Config(16, 64, 64, 3, 2),
Config(16, 128, 64, 3, 4),
Config(16, 256, 64, 3, 4),
Config(32, 32, 32, 3, 2),
Config(32, 64, 32, 3, 2),
Config(32, 128, 32, 3, 4),
Config(32, 256, 32, 3, 4),
Config(32, 32, 64, 3, 2),
Config(32, 64, 64, 3, 2),
Config(32, 128, 64, 3, 4),
Config(32, 256, 64, 3, 4),
Config(16, 32, 32, 4, 2),
Config(16, 64, 32, 4, 2),
Config(16, 128, 32, 4, 4),
Config(16, 256, 32, 4, 4),
Config(16, 32, 64, 4, 2),
Config(16, 64, 64, 4, 2),
Config(16, 128, 64, 4, 4),
Config(16, 256, 64, 4, 4),
Config(32, 32, 32, 4, 2),
Config(32, 64, 32, 4, 2),
Config(32, 128, 32, 4, 4),
Config(32, 256, 32, 4, 4),
Config(32, 32, 64, 4, 2),
Config(32, 64, 64, 4, 2),
Config(32, 128, 64, 4, 4),
Config(32, 256, 64, 4, 4),
Config(16, 32, 32, 5, 2),
Config(16, 64, 32, 5, 2),
Config(16, 128, 32, 5, 4),
Config(16, 256, 32, 5, 4),
Config(16, 32, 64, 5, 2),
Config(16, 64, 64, 5, 2),
Config(16, 128, 64, 5, 4),
Config(16, 256, 64, 5, 4),
Config(32, 32, 32, 5, 2),
Config(32, 64, 32, 5, 2),
Config(32, 128, 32, 5, 4),
Config(32, 256, 32, 5, 4),
Config(32, 32, 64, 5, 2),
Config(32, 64, 64, 5, 2),
Config(32, 128, 64, 5, 4),
Config(32, 256, 64, 5, 4),
Config(16, 32, 32, 6, 2),
Config(16, 64, 32, 6, 2),
Config(16, 128, 32, 6, 4),
Config(16, 256, 32, 6, 4),
Config(16, 32, 64, 6, 2),
Config(16, 64, 64, 6, 2),
Config(16, 128, 64, 6, 4),
Config(16, 256, 64, 6, 4),
Config(32, 32, 32, 6, 2),
Config(32, 64, 32, 6, 2),
Config(32, 128, 32, 6, 4),
Config(32, 256, 32, 6, 4),
Config(32, 32, 64, 6, 2),
Config(32, 64, 64, 6, 2),
Config(32, 128, 64, 6, 4),
Config(32, 256, 64, 6, 4),
]
self.scaled_persistent_mm_configs = [
Config(128, 128, 64, 3, 8),
Config(128, 128, 128, 3, 8),
Config(128, 128, 128, 4, 8),
Config(128, 128, 128, 4, 4),
Config(128, 128, 128, 3, 4),
Config(128, 128, 128, 5, 4),
Config(128, 128, 128, 5, 8),
Config(128, 128, 128, 6, 8),
Config(128, 128, 64, 4, 8),
]
# TODO: Unify with other GEMM patterns; mm_plus_mm currently follows a
# slightly different pattern than the rest
self.mm_plus_mm_configs = [
Config(64, 64, 32, 2, 4),
Config(64, 64, 32, 3, 8),
Config(64, 64, 32, 4, 16),
Config(64, 32, 32, 4, 8),
Config(32, 64, 32, 4, 8),
Config(128, 128, 32, 1, 8),
Config(64, 64, 64, 1, 8),
Config(32, 32, 128, 1, 8),
Config(64, 64, 16, 2, 4),
Config(32, 32, 16, 1, 2),
]
self.conv_configs = [
Config(64, 256, 16, 2, 4),
Config(256, 64, 16, 2, 4),
Config(1024, 16, 16, 1, 8),
Config(128, 128, 32, 2, 8),
Config(64, 64, 32, 2, 4),
Config(64, 256, 32, 2, 8),
Config(256, 64, 32, 2, 8),
]
def _finalize_mm_configs(
self,
configs: List[Config],
) -> Generator[TritonConfig, None, None]:
"""
Finalizes configs after scaling, applying additional constraints.
"""
used = OrderedSet[Config]()
max_mm_configs = config.test_configs.max_mm_configs
for block_m, block_n, block_k, num_stages, num_warps in configs:
# Each warp computes a 16x16 tile = 256 elements
num_warps = min(num_warps, block_m * block_n // 256)
if (block_m, block_n, block_k, num_stages, num_warps, 0) not in used and (
max_mm_configs is None or len(used) < max_mm_configs
):
used.add(Config(block_m, block_n, block_k, num_stages, num_warps))
yield self.triton_config(
BLOCK_M=block_m,
BLOCK_N=block_n,
BLOCK_K=block_k,
num_stages=num_stages,
num_warps=num_warps,
)
def _scale_mm_configs(
self,
m: int,
n: int,
k: int,
configs: Sequence[Config],
scale: float,
has_int8_tensor: bool,
exclude: Callable[[int, int, int], bool],
) -> List[Config]:
"""
Scales and filters matrix multiplication configs based on input size.
"""
from .runtime.runtime_utils import next_power_of_2
min_block_size = 16
min_block_size_k = 32 if has_int8_tensor else 16
m = max(
next_power_of_2(
V.graph.sizevars.size_hint(
m, fallback=config.unbacked_symint_fallback # type: ignore[arg-type]
)
),
min_block_size,
)
n = max(
next_power_of_2(
V.graph.sizevars.size_hint(
n, fallback=config.unbacked_symint_fallback # type: ignore[arg-type]
)
),
min_block_size,
)
k = max(
next_power_of_2(
V.graph.sizevars.size_hint(
k, fallback=config.unbacked_symint_fallback # type: ignore[arg-type]
)
),
min_block_size_k,
)
scaled_configs = []
for c in configs:
scaled_config = c._replace(
block_m=max(min(int(c.block_m * scale), m), min_block_size),
block_n=max(min(int(c.block_n * scale), n), min_block_size),
block_k=max(min(int(c.block_k * scale), k), min_block_size_k),
)
if not exclude(
scaled_config.block_m, scaled_config.block_n, scaled_config.block_k
):
scaled_configs.append(scaled_config)
return scaled_configs
def preprocess_mm_configs(
self,
m: int,
n: int,
k: int,
configs: Sequence[Config],
has_int8_tensor: bool = False,
scale: int = 1,
exclude: Callable[[int, int, int], bool] = lambda m, n, k: False,
) -> Generator[TritonConfig, None, None]:
scaled_configs = self._scale_mm_configs(
m, n, k, configs, scale, has_int8_tensor, exclude
)
return self._finalize_mm_configs(scaled_configs)
def triton_config(
self, num_stages: int, num_warps: int, **kwargs: Any
) -> TritonConfig:
from triton import Config as TritonConfig # type: ignore[attr-defined]
return TritonConfig(kwargs, num_stages=num_stages, num_warps=num_warps)
def get_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
return partial(self.preprocess_mm_configs, configs=self.mm_configs)
def get_exhaustive_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
return partial(self.preprocess_mm_configs, configs=self.exhaustive_configs)
def get_extra_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
return partial(self.preprocess_mm_configs, configs=self.extra_mm_configs)
def get_int8_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
return partial(self.preprocess_mm_configs, configs=self.int8_mm_configs)
def get_mixed_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
mm_configs = (
self.mm_configs + self.mixed_mm_configs
if config.max_autotune_gemm_search_space == "EXHAUSTIVE"
else self.mm_configs
)
return partial(self.preprocess_mm_configs, configs=mm_configs)
def get_persistent_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
return partial(self.preprocess_mm_configs, configs=self.persistent_mm_configs)
def get_scaled_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
return partial(self.preprocess_mm_configs, configs=self.scaled_mm_configs)
def get_scaled_persistent_mm_configs(
self,
) -> partial[Generator[TritonConfig, None, None]]:
return partial(
self.preprocess_mm_configs, configs=self.scaled_persistent_mm_configs
)
def get_mm_plus_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
return partial(self._finalize_mm_configs, configs=self.mm_plus_mm_configs)
def get_conv_configs(self) -> partial[Generator[TritonConfig, None, None]]:
return partial(self.preprocess_mm_configs, configs=self.conv_configs)
def generate_mixed_mm_config(self, m: int, n: int, k: int) -> TritonConfig:
if m <= 16 and n >= 4096 and k >= 4096:
return self.triton_config(
BLOCK_M=16,
BLOCK_N=64,
BLOCK_K=128,
num_stages=5,
num_warps=4,
)
elif m > 16 and m <= 32 and n >= 4096 and k >= 4096:
return self.triton_config(
BLOCK_M=32,
BLOCK_N=32,
BLOCK_K=128,
num_stages=5,
num_warps=4,
)
elif m > 32 and m <= 64 and n >= 4096 and k >= 4096:
return self.triton_config(
BLOCK_M=64,
BLOCK_N=32,
BLOCK_K=128,
num_stages=5,
num_warps=4,
)
class CPUConfigHeuristic(BaseConfigHeuristic):
pass
class CUDAConfigHeuristic(BaseConfigHeuristic):
pass
class ROCmConfigHeuristic(BaseConfigHeuristic):
def __init__(self) -> None:
super().__init__()
self.default_num_stages = get_backend_num_stages()
# Exhaustive search for mm configs
self.exhaustive_configs = [
Config(BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product(
[16, 32, 64, 128, 256], repeat=3
)
for num_stages in [1, self.default_num_stages]
for num_warps in [4, 8]
]
def _filter_configs(
self, configs: List[Config], new_num_stages: int
) -> List[Config]:
filtered_configs = [
c._replace(num_stages=self.default_num_stages) for c in configs
]
return filtered_configs
def _finalize_mm_configs(
self,
configs: List[Config],
) -> Generator[TritonConfig, None, None]:
used = OrderedSet[Tuple[Config, int]]()
max_mm_configs = config.test_configs.max_mm_configs
for block_m, block_n, block_k, num_stages, num_warps in configs:
# each warp computes 16x16 tile = 256
num_warps = min(num_warps, block_m * block_n // 256)
for matrix_instr_nonkdim in [0, 16]:
if matrix_instr_nonkdim != 0 and (
block_m % matrix_instr_nonkdim != 0
or block_n % matrix_instr_nonkdim != 0
):
# block_m and block_n must be a multiple of matrix_instr_nonkdim
continue
if (
block_m,
block_n,
block_k,
num_stages,
num_warps,
matrix_instr_nonkdim,
) not in used and (
max_mm_configs is None or len(used) < max_mm_configs
):
used.add(
(
Config(
block_m,
block_n,
block_k,
num_stages,
num_warps,
),
matrix_instr_nonkdim,
)
)
yield self.triton_config(
BLOCK_M=block_m,
BLOCK_N=block_n,
BLOCK_K=block_k,
num_stages=num_stages,
num_warps=num_warps,
matrix_instr_nonkdim=matrix_instr_nonkdim,
)
def get_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
filtered_configs = self._filter_configs(
self.mm_configs, self.default_num_stages
)
return partial(self.preprocess_mm_configs, configs=filtered_configs)
def get_exhaustive_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
filtered_configs = self._filter_configs(
self.exhaustive_configs, self.default_num_stages
)
return partial(self.preprocess_mm_configs, configs=filtered_configs)
def get_extra_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
filtered_configs = self._filter_configs(
self.extra_mm_configs, self.default_num_stages
)
return partial(self.preprocess_mm_configs, configs=filtered_configs)
def get_int8_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
filtered_configs = self._filter_configs(
self.int8_mm_configs, self.default_num_stages
)
return partial(self.preprocess_mm_configs, configs=filtered_configs)
def get_mixed_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
mm_configs = (
self.mm_configs + self.mixed_mm_configs
if config.max_autotune_gemm_search_space == "EXHAUSTIVE"
else self.mm_configs
)
filtered_configs = self._filter_configs(mm_configs, self.default_num_stages)
return partial(self.preprocess_mm_configs, configs=filtered_configs)
def get_persistent_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
filtered_configs = self._filter_configs(
self.persistent_mm_configs, self.default_num_stages
)
return partial(self.preprocess_mm_configs, configs=filtered_configs)
def get_scaled_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
filtered_configs = self._filter_configs(
self.scaled_mm_configs, self.default_num_stages
)
return partial(self.preprocess_mm_configs, configs=filtered_configs)
def get_scaled_persistent_mm_configs(
self,
) -> partial[Generator[TritonConfig, None, None]]:
filtered_configs = self._filter_configs(
self.scaled_persistent_mm_configs, self.default_num_stages
)
return partial(self.preprocess_mm_configs, configs=filtered_configs)
def get_mm_plus_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
filtered_configs = self._filter_configs(self.mm_plus_mm_configs, 1)
return partial(self._finalize_mm_configs, configs=filtered_configs)
def get_conv_configs(self) -> partial[Generator[TritonConfig, None, None]]:
filtered_configs = self._filter_configs(
self.conv_configs, self.default_num_stages
)
return partial(self.preprocess_mm_configs, configs=filtered_configs)
def generate_mixed_mm_config(self, m: int, n: int, k: int) -> TritonConfig:
if m <= 16 and n >= 4096 and k >= 4096:
return self.triton_config(
BLOCK_M=16,
BLOCK_N=64,
BLOCK_K=128,
num_stages=self.default_num_stages,
num_warps=4,
)
elif m > 16 and m <= 32 and n >= 4096 and k >= 4096:
return self.triton_config(
BLOCK_M=32,
BLOCK_N=32,
BLOCK_K=128,
num_stages=self.default_num_stages,
num_warps=4,
)
elif m > 32 and m <= 64 and n >= 4096 and k >= 4096:
return self.triton_config(
BLOCK_M=64,
BLOCK_N=32,
BLOCK_K=128,
num_stages=self.default_num_stages,
num_warps=4,
)
class XPUConfigHeuristic(BaseConfigHeuristic):
pass
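The deleted template_heuristics.py above keyed its per-device heuristic classes off a thread-safe singleton metaclass so that the large config tables built in __init__ are constructed only once per class. For reference, here is a minimal standalone sketch of that pattern (hypothetical names, not part of the tree), showing that repeated construction of a heuristic subclass returns one cached instance.

# Illustrative sketch of the BaseConfigSingleton pattern used by the removed file.
from threading import Lock
from typing import Any


class ConfigSingletonSketch(type):
    _instances: dict[type, Any] = {}
    _lock: Lock = Lock()

    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
        # One cached instance per subclass; the lock keeps concurrent first calls safe.
        with cls._lock:
            if cls not in cls._instances:
                cls._instances[cls] = super().__call__()
            return cls._instances[cls]


class ExampleHeuristic(metaclass=ConfigSingletonSketch):
    def __init__(self) -> None:
        # Stand-in for the heavy config tables built in BaseConfigHeuristic.__init__.
        self.mm_configs = [(64, 64, 32, 2, 4)]


assert ExampleHeuristic() is ExampleHeuristic()  # same cached object every time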