pytorch/torch/_inductor/template_heuristics.py
Nikhil Anil Patel 5acc3e286a [Inductor] Add Additional Configs for persistent+TMA version of Triton mm and addmm (#150587)
Summary:
This PR introduces additional autotuning configurations for the persistent+TMA version of Triton `mm` and `addmm` operations. The new configurations are as follows:
* `(128, 128, 64, 5, 8)`
* `(256, 128, 64, 4, 8)`
* `(128, 128, 64, 5, 4)`

These configurations were selected based on exhaustive autotuning performed on commonly used shapes from an internal foundational model.

While these new configs are generally more performant across the board, we see notable gains in a few specific cases:
* In scenarios where `n >> m, k`, the configurations `(128, 128, 64, 5, 8)` and `(256, 128, 64, 4, 8)` tend to produce an additional 5-10% speedup over the aten baseline compared to the original configurations.
* Similarly, the configuration `(128, 128, 64, 5, 4)` yields approximately an 8% improvement in scenarios where `k >> m, n`.

These enhancements are expected to provide performance benefits across diverse use cases, particularly when compared to the original set of configurations.

Test Plan:
contbuild & OSS CI

Reviewers: paulzhan

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150587
Approved by: https://github.com/PaulZhang12, https://github.com/drisspg, https://github.com/eellison
2025-04-23 18:21:35 +00:00

706 lines
26 KiB
Python

from __future__ import annotations
import dataclasses
import itertools
from functools import partial
from threading import Lock
from typing import Any, Callable, TYPE_CHECKING
from torch.utils._ordered_set import OrderedSet
from . import config
from .utils import get_backend_num_stages
from .virtualized import V
if TYPE_CHECKING:
from collections.abc import Generator
from triton import Config as TritonConfig
@dataclasses.dataclass()
class BaseConfig:
    """
    Tile and launch parameters shared by every Triton GEMM-style template.

    Convolutions use this type directly (see the ``ConvConfig`` alias); GEMM
    and ROCm variants extend it with extra tuneables.
    """

    # Emitted to the kernel as BLOCK_M.
    block_m: int
    # Emitted to the kernel as BLOCK_N.
    block_n: int
    # Emitted to the kernel as BLOCK_K.
    block_k: int
    # Forwarded to the Triton config as ``num_stages``.
    num_stages: int
    # Forwarded to the Triton config as ``num_warps`` (may be clamped later).
    num_warps: int
@dataclasses.dataclass()
class GemmConfig(BaseConfig):
    """
    GEMM tile configuration used by the generic (CPU, CUDA, XPU) heuristics.

    Extends the base tile parameters with the grouping factor that is emitted
    to the kernel as GROUP_M.
    """

    # Emitted to the kernel as GROUP_M; defaults to 8 when not specified.
    group_m: int = 8


# Convolution templates take the base tile parameters unchanged (no GROUP_M).
ConvConfig = BaseConfig
@dataclasses.dataclass()
class ROCmGemmConfig(GemmConfig):
    """
    GEMM configuration for the ROCm (AMD) backend.

    Adds the AMD-specific Triton kernel arguments on top of ``GemmConfig``.
    """

    # Forwarded as ``matrix_instr_nonkdim``; when non-zero, configs whose
    # block_m/block_n are not multiples of it are dropped at finalization.
    matrix_instr_nonkdim: int = 16
    # Forwarded as ``waves_per_eu``; ROCm finalization replaces any non-zero
    # value with ``8 // num_warps``.
    waves_per_eu: int = 0
    # Forwarded as ``kpack``.
    kpack: int = 2
@dataclasses.dataclass()
class ROCmConvConfig(ConvConfig):
    """
    Convolution configuration for the ROCm (AMD) backend.

    Adds the AMD-specific Triton kernel arguments on top of ``ConvConfig``.
    """

    # Forwarded as ``matrix_instr_nonkdim``; when non-zero, configs whose
    # block_m/block_n are not multiples of it are dropped at finalization.
    matrix_instr_nonkdim: int = 16
    # Forwarded as ``waves_per_eu``.
    waves_per_eu: int = 0
    # Forwarded as ``kpack``.
    kpack: int = 2
class BaseHeuristicSingleton(type):
    """
    Thread-safe singleton metaclass for the config heuristic subclasses.

    Each class using this metaclass is constructed at most once; every later
    call returns the cached instance, so the heavy ``__init__`` (which builds
    the config tables) is not re-run.
    """

    # One cached instance per class object that uses this metaclass.
    _instances: dict[type[Any], Any] = {}
    # Guards first-time construction so concurrent callers cannot both build
    # an instance of the same class.
    _lock: Lock = Lock()

    def __call__(
        cls: BaseHeuristicSingleton, *args: Any, **kwargs: Any
    ) -> BaseConfigHeuristic:
        """
        Return the cached instance of ``cls``, constructing it on first use.

        Constructor arguments are forwarded to ``__init__`` on the first call
        only; later calls return the existing instance and ignore arguments.
        """
        with cls._lock:
            if cls not in cls._instances:
                # Forward the caller's arguments instead of silently dropping
                # them (the previous implementation called __init__ with none).
                cls._instances[cls] = super().__call__(*args, **kwargs)
            return cls._instances[cls]
class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
    """
    Base class for mm_configs; device-specific Triton kernel config heuristics
    inherit from here.

    Holds the candidate config tables and turns them into Triton configs
    (scaled to the problem size, filtered and deduplicated) via the
    ``get_*_configs`` accessors, each of which returns a ``partial`` that the
    caller invokes with the problem sizes.
    """

    def __init__(self) -> None:
        # Candidate configs are stored as
        # GemmConfig(BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps[, group_m]).
        self.mm_configs: list[BaseConfig] = [
            GemmConfig(32, 32, 16, 1, 2),
            GemmConfig(32, 32, 128, 2, 4),
            GemmConfig(32, 64, 32, 5, 8),
            GemmConfig(64, 32, 32, 5, 8),
            GemmConfig(64, 32, 128, 5, 4),
            GemmConfig(64, 64, 16, 2, 4),
            GemmConfig(64, 64, 32, 2, 4),
            GemmConfig(64, 64, 64, 3, 8),
            GemmConfig(64, 64, 128, 5, 4),
            GemmConfig(64, 128, 32, 3, 4),
            GemmConfig(64, 128, 32, 4, 8),
            GemmConfig(64, 128, 64, 3, 4),
            GemmConfig(64, 128, 128, 4, 4),
            GemmConfig(128, 64, 32, 3, 4),
            GemmConfig(128, 64, 32, 4, 8),
            GemmConfig(128, 128, 32, 2, 8),
            GemmConfig(128, 128, 32, 3, 4),
            GemmConfig(128, 128, 64, 3, 4),
            GemmConfig(128, 128, 64, 5, 8),
        ]

        # Exhaustive search for mm configs
        self.exhaustive_configs: list[BaseConfig] = [
            GemmConfig(BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps, group_m)
            for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product(
                [16, 32, 64, 128, 256], repeat=3
            )
            for num_stages in [1, 2, 3, 4, 5]
            for num_warps in [2, 4, 8]
            for group_m in [8]
        ]

        # these are only used in tuned_mm when AutoHeuristic is enabled
        # the idea is that when AutoHeuristic collects data to learn a heuristic, more configs are autotuned
        # when the learned heuristic is used, the learned heuristic reduces the number of configs down to 10
        # which saves compilation time (since less configs are autotuned) and potentially increase performance
        # because the learned heuristic might predict a config that is not part mm_configs
        self.extra_mm_configs: list[BaseConfig] = [
            GemmConfig(16, 32, 16, 3, 2),
            GemmConfig(16, 32, 32, 4, 2),
            GemmConfig(16, 32, 32, 5, 2),
            GemmConfig(64, 64, 128, 3, 4),
            GemmConfig(128, 64, 32, 2, 2),
            GemmConfig(128, 64, 64, 3, 8),
            GemmConfig(128, 64, 128, 4, 8),
            GemmConfig(128, 128, 32, 4, 4),
            GemmConfig(128, 128, 64, 3, 8),
            GemmConfig(128, 128, 64, 5, 4),
        ]

        self.int8_mm_configs: list[BaseConfig] = [
            GemmConfig(64, 64, 32, 2, 4),
            GemmConfig(64, 128, 32, 3, 4),
            GemmConfig(128, 64, 32, 3, 4),
            GemmConfig(64, 128, 32, 4, 8),
            GemmConfig(128, 64, 32, 4, 8),
            GemmConfig(64, 32, 32, 5, 8),
            GemmConfig(32, 64, 32, 5, 8),
            GemmConfig(128, 128, 32, 2, 8),
            GemmConfig(64, 64, 64, 3, 8),
            GemmConfig(128, 256, 128, 3, 8),
            GemmConfig(256, 128, 128, 3, 8),
        ]

        self.mixed_mm_configs: list[BaseConfig] = [
            GemmConfig(16, 128, 256, 3, 4),
            GemmConfig(16, 128, 256, 5, 8),
        ]

        self.persistent_mm_configs: list[BaseConfig] = [
            GemmConfig(128, 256, 64, 3, 8),
            GemmConfig(128, 128, 64, 3, 8),
            GemmConfig(128, 128, 128, 3, 8),
            GemmConfig(128, 128, 128, 3, 4),
            GemmConfig(128, 128, 64, 4, 8),
            GemmConfig(128, 128, 64, 5, 8),
            GemmConfig(256, 128, 64, 4, 8),
            GemmConfig(128, 128, 64, 5, 4),
        ]

        self.scaled_mm_configs: list[BaseConfig] = [
            GemmConfig(128, 256, 32, 3, 8),
            GemmConfig(256, 128, 32, 3, 8),
            GemmConfig(256, 64, 32, 4, 4),
            GemmConfig(64, 256, 32, 4, 4),
            GemmConfig(128, 128, 32, 4, 4),
            GemmConfig(128, 64, 32, 4, 4),
            GemmConfig(64, 128, 32, 4, 4),
            GemmConfig(128, 32, 32, 4, 4),
            GemmConfig(64, 32, 32, 5, 2),
            GemmConfig(256, 128, 128, 3, 8),
            GemmConfig(256, 64, 128, 4, 4),
            GemmConfig(64, 256, 128, 4, 4),
            GemmConfig(128, 128, 128, 4, 4),
            GemmConfig(128, 64, 64, 4, 4),
            GemmConfig(64, 128, 64, 4, 4),
            GemmConfig(128, 32, 64, 4, 4),
            GemmConfig(64, 32, 64, 5, 2),
            GemmConfig(16, 32, 32, 2, 2),
            GemmConfig(16, 64, 32, 2, 2),
            GemmConfig(16, 128, 32, 2, 4),
            GemmConfig(16, 256, 32, 2, 4),
            GemmConfig(16, 32, 64, 2, 2),
            GemmConfig(16, 64, 64, 2, 2),
            GemmConfig(16, 128, 64, 2, 4),
            GemmConfig(16, 256, 64, 2, 4),
            GemmConfig(32, 32, 32, 2, 2),
            GemmConfig(32, 64, 32, 2, 2),
            GemmConfig(32, 128, 32, 2, 4),
            GemmConfig(32, 256, 32, 2, 4),
            GemmConfig(32, 32, 64, 2, 2),
            GemmConfig(32, 64, 64, 2, 2),
            GemmConfig(32, 128, 64, 2, 4),
            GemmConfig(32, 256, 64, 2, 4),
            GemmConfig(16, 32, 32, 3, 2),
            GemmConfig(16, 64, 32, 3, 2),
            GemmConfig(16, 128, 32, 3, 4),
            GemmConfig(16, 256, 32, 3, 4),
            GemmConfig(16, 32, 64, 3, 2),
            GemmConfig(16, 64, 64, 3, 2),
            GemmConfig(16, 128, 64, 3, 4),
            GemmConfig(16, 256, 64, 3, 4),
            GemmConfig(32, 32, 32, 3, 2),
            GemmConfig(32, 64, 32, 3, 2),
            GemmConfig(32, 128, 32, 3, 4),
            GemmConfig(32, 256, 32, 3, 4),
            GemmConfig(32, 32, 64, 3, 2),
            GemmConfig(32, 64, 64, 3, 2),
            GemmConfig(32, 128, 64, 3, 4),
            GemmConfig(32, 256, 64, 3, 4),
            GemmConfig(16, 32, 32, 4, 2),
            GemmConfig(16, 64, 32, 4, 2),
            GemmConfig(16, 128, 32, 4, 4),
            GemmConfig(16, 256, 32, 4, 4),
            GemmConfig(16, 32, 64, 4, 2),
            GemmConfig(16, 64, 64, 4, 2),
            GemmConfig(16, 128, 64, 4, 4),
            GemmConfig(16, 256, 64, 4, 4),
            GemmConfig(32, 32, 32, 4, 2),
            GemmConfig(32, 64, 32, 4, 2),
            GemmConfig(32, 128, 32, 4, 4),
            GemmConfig(32, 256, 32, 4, 4),
            GemmConfig(32, 32, 64, 4, 2),
            GemmConfig(32, 64, 64, 4, 2),
            GemmConfig(32, 128, 64, 4, 4),
            GemmConfig(32, 256, 64, 4, 4),
            GemmConfig(16, 32, 32, 5, 2),
            GemmConfig(16, 64, 32, 5, 2),
            GemmConfig(16, 128, 32, 5, 4),
            GemmConfig(16, 256, 32, 5, 4),
            GemmConfig(16, 32, 64, 5, 2),
            GemmConfig(16, 64, 64, 5, 2),
            GemmConfig(16, 128, 64, 5, 4),
            GemmConfig(16, 256, 64, 5, 4),
            GemmConfig(32, 32, 32, 5, 2),
            GemmConfig(32, 64, 32, 5, 2),
            GemmConfig(32, 128, 32, 5, 4),
            GemmConfig(32, 256, 32, 5, 4),
            GemmConfig(32, 32, 64, 5, 2),
            GemmConfig(32, 64, 64, 5, 2),
            GemmConfig(32, 128, 64, 5, 4),
            GemmConfig(32, 256, 64, 5, 4),
            GemmConfig(16, 32, 32, 6, 2),
            GemmConfig(16, 64, 32, 6, 2),
            GemmConfig(16, 128, 32, 6, 4),
            GemmConfig(16, 256, 32, 6, 4),
            GemmConfig(16, 32, 64, 6, 2),
            GemmConfig(16, 64, 64, 6, 2),
            GemmConfig(16, 128, 64, 6, 4),
            GemmConfig(16, 256, 64, 6, 4),
            GemmConfig(32, 32, 32, 6, 2),
            GemmConfig(32, 64, 32, 6, 2),
            GemmConfig(32, 128, 32, 6, 4),
            GemmConfig(32, 256, 32, 6, 4),
            GemmConfig(32, 32, 64, 6, 2),
            GemmConfig(32, 64, 64, 6, 2),
            GemmConfig(32, 128, 64, 6, 4),
            GemmConfig(32, 256, 64, 6, 4),
        ]

        self.scaled_persistent_mm_configs: list[BaseConfig] = [
            GemmConfig(128, 128, 64, 3, 8),
            GemmConfig(128, 128, 128, 3, 8),
            GemmConfig(128, 128, 128, 4, 8),
            GemmConfig(128, 128, 128, 4, 4),
            GemmConfig(128, 128, 128, 3, 4),
            GemmConfig(128, 128, 128, 5, 4),
            GemmConfig(128, 128, 128, 5, 8),
            GemmConfig(128, 128, 128, 6, 8),
            GemmConfig(128, 128, 64, 4, 8),
        ]

        # TODO: Unify with other gemm patterns, mm_plus_mm currently follows
        # slightly different pattern than rest
        self.mm_plus_mm_configs: list[BaseConfig] = [
            GemmConfig(64, 64, 32, 2, 4),
            GemmConfig(64, 64, 32, 3, 8),
            GemmConfig(64, 64, 32, 4, 16),
            GemmConfig(64, 32, 32, 4, 8),
            GemmConfig(32, 64, 32, 4, 8),
            GemmConfig(128, 128, 32, 1, 8),
            GemmConfig(64, 64, 64, 1, 8),
            GemmConfig(32, 32, 128, 1, 8),
            GemmConfig(64, 64, 16, 2, 4),
            GemmConfig(32, 32, 16, 1, 2),
        ]

        self.conv_configs: list[BaseConfig] = [
            ConvConfig(64, 256, 16, 2, 4),
            ConvConfig(256, 64, 16, 2, 4),
            ConvConfig(1024, 16, 16, 1, 8),
            ConvConfig(128, 128, 32, 2, 8),
            ConvConfig(64, 64, 32, 2, 4),
            ConvConfig(64, 256, 32, 2, 8),
            ConvConfig(256, 64, 32, 2, 8),
        ]

    def _finalize_mm_configs(
        self,
        configs: list[BaseConfig],
    ) -> Generator[TritonConfig, None, None]:
        """
        Finalizes configs after scaling, applying additional constraints.

        ``num_warps`` is clamped to the number of 256-element sub-tiles in the
        block, duplicates (after clamping) are skipped, and at most
        ``config.test_configs.max_mm_configs`` configs are yielded when set.
        """
        used: OrderedSet[tuple[int, ...]] = OrderedSet()
        max_mm_configs = config.test_configs.max_mm_configs

        for conf in configs:
            # Each warp computes a 16x16 tile = 256 elements
            num_warps = min(conf.num_warps, conf.block_m * conf.block_n // 256)

            # Construct key for finding duplicate configs
            key: tuple[int, ...] = (
                conf.block_m,
                conf.block_n,
                conf.block_k,
                conf.num_stages,
                num_warps,
            )

            # Check if gemm specific arg exists - add to key if does
            group_m = getattr(conf, "group_m", None)
            if group_m is not None:
                key += (group_m,)

            if key not in used and (
                max_mm_configs is None or len(used) < max_mm_configs
            ):
                used.add(key)
                kwargs = {
                    "BLOCK_M": conf.block_m,
                    "BLOCK_N": conf.block_n,
                    "BLOCK_K": conf.block_k,
                    "num_stages": conf.num_stages,
                    "num_warps": num_warps,
                }
                if group_m is not None:
                    kwargs["GROUP_M"] = group_m
                yield self.triton_config(**kwargs)

    def _scale_mm_configs(
        self,
        m: int,
        n: int,
        k: int,
        configs: list[BaseConfig],
        scale: float,
        has_int8_tensor: bool,
        exclude: Callable[[int, int, int], bool],
    ) -> list[BaseConfig]:
        """
        Scales and filters matrix multiplication configs based on input size.

        Each block dimension is multiplied by ``scale``, capped at the problem
        dimension rounded up to a power of two, and floored at the minimum block
        size (16, or 32 for ``block_k`` when ``has_int8_tensor``). Configs for
        which ``exclude(block_m, block_n, block_k)`` is true are dropped.
        Returns new config objects; the inputs are not modified.
        """
        from .runtime.runtime_utils import next_power_of_2

        min_block_size = 16
        min_block_size_k = 32 if has_int8_tensor else 16

        def _round_up_dim(dim: int, lower_bound: int) -> int:
            # Resolve a (possibly symbolic) size to a concrete hint, round it up
            # to the next power of two, and never go below ``lower_bound``.
            hint = V.graph.sizevars.size_hint(
                dim,
                fallback=config.unbacked_symint_fallback,  # type: ignore[arg-type]
            )
            return max(next_power_of_2(hint), lower_bound)

        m = _round_up_dim(m, min_block_size)
        n = _round_up_dim(n, min_block_size)
        k = _round_up_dim(k, min_block_size_k)

        scaled_configs = []
        for c in configs:
            scaled_config = dataclasses.replace(
                c,
                block_m=max(min(int(c.block_m * scale), m), min_block_size),
                block_n=max(min(int(c.block_n * scale), n), min_block_size),
                block_k=max(min(int(c.block_k * scale), k), min_block_size_k),
            )

            if not exclude(
                scaled_config.block_m, scaled_config.block_n, scaled_config.block_k
            ):
                scaled_configs.append(scaled_config)

        return scaled_configs

    def preprocess_mm_configs(
        self,
        m: int,
        n: int,
        k: int,
        configs: list[BaseConfig],
        has_int8_tensor: bool = False,
        scale: float = 1,
        exclude: Callable[[int, int, int], bool] = lambda m, n, k: False,
    ) -> Generator[TritonConfig, None, None]:
        """
        Scale ``configs`` to the (m, n, k) problem and finalize them.

        ``scale`` is annotated as ``float`` to match ``_scale_mm_configs``,
        which multiplies block sizes by it and truncates with ``int``.
        """
        scaled_configs = self._scale_mm_configs(
            m, n, k, configs, scale, has_int8_tensor, exclude
        )
        return self._finalize_mm_configs(scaled_configs)

    def triton_config(
        self, num_stages: int, num_warps: int, **kwargs: Any
    ) -> TritonConfig:
        """Build a ``triton.Config`` whose kernel meta-parameters are ``kwargs``."""
        from triton import Config as TritonConfig  # type: ignore[attr-defined]

        return TritonConfig(kwargs, num_stages=num_stages, num_warps=num_warps)

    def get_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        return partial(self.preprocess_mm_configs, configs=self.mm_configs)

    def get_exhaustive_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        return partial(self.preprocess_mm_configs, configs=self.exhaustive_configs)

    def get_extra_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        return partial(self.preprocess_mm_configs, configs=self.extra_mm_configs)

    def get_int8_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        return partial(self.preprocess_mm_configs, configs=self.int8_mm_configs)

    def get_mixed_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        # The mixed-mm extras are only searched in EXHAUSTIVE mode.
        mm_configs = (
            self.mm_configs + self.mixed_mm_configs
            if config.max_autotune_gemm_search_space == "EXHAUSTIVE"
            else self.mm_configs
        )
        return partial(self.preprocess_mm_configs, configs=mm_configs)

    def get_persistent_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        return partial(self.preprocess_mm_configs, configs=self.persistent_mm_configs)

    def get_scaled_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        return partial(self.preprocess_mm_configs, configs=self.scaled_mm_configs)

    def get_scaled_persistent_mm_configs(
        self,
    ) -> partial[Generator[TritonConfig, None, None]]:
        return partial(
            self.preprocess_mm_configs, configs=self.scaled_persistent_mm_configs
        )

    def get_mm_plus_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        # mm_plus_mm configs are finalized without size-based scaling.
        return partial(self._finalize_mm_configs, configs=self.mm_plus_mm_configs)

    def get_conv_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        return partial(self.preprocess_mm_configs, configs=self.conv_configs)
class CPUConfigHeuristic(BaseConfigHeuristic):
    """CPU config heuristic; uses the base config tables and logic unchanged."""
class CUDAConfigHeuristic(BaseConfigHeuristic):
    """CUDA config heuristic; uses the base config tables and logic unchanged."""
class ROCmConfigHeuristic(BaseConfigHeuristic):
    """
    ROCm (AMD) config heuristic.

    Replaces the mm and exhaustive tables with ``ROCmGemmConfig`` entries,
    forces the backend's default ``num_stages`` onto the shared base tables,
    and emits the AMD-specific kernel arguments during finalization.
    """

    def __init__(self) -> None:
        super().__init__()

        self.default_num_stages = get_backend_num_stages()
        stages = self.default_num_stages

        self.mm_configs: list[BaseConfig] = [
            ROCmGemmConfig(16, 16, 256, stages, 4, group_m=4, waves_per_eu=2),
            ROCmGemmConfig(32, 16, 256, stages, 4, group_m=4),
            ROCmGemmConfig(32, 32, 16, stages, 4, group_m=8, waves_per_eu=2),
            ROCmGemmConfig(32, 32, 128, stages, 4, group_m=8),
            ROCmGemmConfig(32, 64, 64, stages, 4, group_m=8),
            ROCmGemmConfig(64, 16, 128, stages, 4, group_m=8, waves_per_eu=2),
            ROCmGemmConfig(64, 32, 32, stages, 4, group_m=8),
            ROCmGemmConfig(64, 32, 64, stages, 4, group_m=8),
            ROCmGemmConfig(64, 32, 64, stages, 8, group_m=8),
            ROCmGemmConfig(64, 32, 128, stages, 4, group_m=8),
            ROCmGemmConfig(64, 64, 16, stages, 4, group_m=8),
            ROCmGemmConfig(64, 64, 64, stages, 4, group_m=4),
            ROCmGemmConfig(64, 64, 128, stages, 8, group_m=16),
            ROCmGemmConfig(64, 64, 256, stages, 8, group_m=4),
            ROCmGemmConfig(64, 128, 32, stages, 4, group_m=4, waves_per_eu=2),
            ROCmGemmConfig(64, 128, 32, stages, 8, group_m=8),
            ROCmGemmConfig(64, 128, 64, stages, 8, group_m=4),
            ROCmGemmConfig(64, 128, 128, stages, 8, group_m=4),
            ROCmGemmConfig(128, 32, 32, stages, 4, group_m=8),
            ROCmGemmConfig(128, 32, 64, stages, 4, group_m=8),
            ROCmGemmConfig(128, 64, 32, stages, 4, group_m=8, waves_per_eu=2),
            ROCmGemmConfig(128, 64, 64, stages, 4, group_m=16),
            ROCmGemmConfig(128, 64, 128, stages, 8, group_m=4),
            ROCmGemmConfig(128, 128, 32, stages, 4, group_m=16, waves_per_eu=2),
            ROCmGemmConfig(128, 128, 32, stages, 8, group_m=16),
            ROCmGemmConfig(128, 128, 32, stages, 8, group_m=16, waves_per_eu=2),
            ROCmGemmConfig(128, 128, 64, stages, 4, group_m=16),
            ROCmGemmConfig(128, 128, 64, stages, 8, group_m=8),
            ROCmGemmConfig(128, 128, 128, stages, 8, group_m=16),
            ROCmGemmConfig(128, 256, 32, stages, 4, group_m=16, waves_per_eu=2),
            ROCmGemmConfig(128, 256, 64, stages, 8, group_m=4),
            ROCmGemmConfig(256, 64, 64, stages, 8, group_m=4),
            ROCmGemmConfig(256, 128, 32, stages, 4, group_m=4, waves_per_eu=2),
            ROCmGemmConfig(256, 128, 32, stages, 8, group_m=16),
            ROCmGemmConfig(256, 128, 64, stages, 8, group_m=4),
            ROCmGemmConfig(256, 256, 64, stages, 8, group_m=4),
        ]

        # Exhaustive search for mm configs
        self.exhaustive_configs: list[BaseConfig] = [
            ROCmGemmConfig(
                BLOCK_M,
                BLOCK_N,
                BLOCK_K,
                num_stages,
                num_warps,
                group_m,
                matrix_instr_nonkdim,
                waves_per_eu,
                kpack,
            )
            for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product(
                [16, 32, 64, 128, 256], repeat=3
            )
            for num_stages in [1, stages]
            for num_warps in [4, 8]
            for group_m in [4, 8, 16]
            for matrix_instr_nonkdim in [0, 16]
            for waves_per_eu in [0, 2]
            for kpack in [2]
        ]

    def _filter_configs(
        self, configs: list[BaseConfig], new_num_stages: int
    ) -> list[BaseConfig]:
        """
        Return copies of ``configs`` with ``num_stages`` set to ``new_num_stages``.

        Previously the requested value was ignored (``default_num_stages`` was
        always written) and the shared config objects on this singleton were
        mutated in place; both are fixed here.
        """
        # TODO: _filter_configs can be removed once backend specific configs are added
        # for all methods
        return [dataclasses.replace(c, num_stages=new_num_stages) for c in configs]

    def _finalize_mm_configs(
        self,
        configs: list[BaseConfig],
    ) -> Generator[TritonConfig, None, None]:
        """
        Finalizes configs after scaling, applying additional constraints.

        On top of the base behavior (warp clamping, deduplication, the
        ``max_mm_configs`` cap) this drops configs whose block_m/block_n are not
        multiples of a non-zero ``matrix_instr_nonkdim``. It also emits the AMD
        kernel arguments, defaulting to 16 / 0 / 2 when a config lacks them.
        Input configs are not modified.
        """
        used: OrderedSet[tuple[int, ...]] = OrderedSet()
        max_mm_configs = config.test_configs.max_mm_configs

        for conf in configs:
            # Each warp computes a 16x16 tile = 256 elements. Kept local so the
            # (possibly shared) config object is not mutated.
            num_warps = min(conf.num_warps, conf.block_m * conf.block_n // 256)

            # Defaults for AMD triton backend kern args if not set
            matrix_instr_nonkdim = getattr(conf, "matrix_instr_nonkdim", 16)
            waves_per_eu = getattr(conf, "waves_per_eu", 0)
            kpack = getattr(conf, "kpack", 2)

            if matrix_instr_nonkdim != 0 and (
                conf.block_m % matrix_instr_nonkdim != 0
                or conf.block_n % matrix_instr_nonkdim != 0
            ):
                # block_m and block_n must be a multiple of matrix_instr_nonkdim
                continue

            # Construct key for finding duplicate configs
            key: tuple[int, ...] = (
                conf.block_m,
                conf.block_n,
                conf.block_k,
                conf.num_stages,
                num_warps,
                waves_per_eu,
                matrix_instr_nonkdim,
                kpack,
            )

            # Check if gemm specific arg exists - add to key if does
            group_m = getattr(conf, "group_m", None)
            if group_m is not None:
                key += (group_m,)

            # A non-zero waves_per_eu is derived from the (clamped) warp count;
            # the dedup key above intentionally uses the configured value.
            if waves_per_eu != 0:
                waves_per_eu = int(8 // num_warps)

            if key not in used and (
                max_mm_configs is None or len(used) < max_mm_configs
            ):
                used.add(key)
                kwargs = {
                    "BLOCK_M": conf.block_m,
                    "BLOCK_N": conf.block_n,
                    "BLOCK_K": conf.block_k,
                    "num_stages": conf.num_stages,
                    "num_warps": num_warps,
                    "matrix_instr_nonkdim": matrix_instr_nonkdim,
                    "waves_per_eu": waves_per_eu,
                    "kpack": kpack,
                }
                if group_m is not None:
                    kwargs["GROUP_M"] = group_m
                yield self.triton_config(**kwargs)

    def get_extra_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        filtered_configs = self._filter_configs(
            self.extra_mm_configs, self.default_num_stages
        )
        return partial(self.preprocess_mm_configs, configs=filtered_configs)

    def get_int8_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        filtered_configs = self._filter_configs(
            self.int8_mm_configs, self.default_num_stages
        )
        return partial(self.preprocess_mm_configs, configs=filtered_configs)

    def get_mixed_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        # The mixed-mm extras are only searched in EXHAUSTIVE mode.
        mm_configs = (
            self.mm_configs + self.mixed_mm_configs
            if config.max_autotune_gemm_search_space == "EXHAUSTIVE"
            else self.mm_configs
        )
        filtered_configs = self._filter_configs(mm_configs, self.default_num_stages)
        return partial(self.preprocess_mm_configs, configs=filtered_configs)

    def get_persistent_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        filtered_configs = self._filter_configs(
            self.persistent_mm_configs, self.default_num_stages
        )
        return partial(self.preprocess_mm_configs, configs=filtered_configs)

    def get_scaled_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        filtered_configs = self._filter_configs(
            self.scaled_mm_configs, self.default_num_stages
        )
        return partial(self.preprocess_mm_configs, configs=filtered_configs)

    def get_scaled_persistent_mm_configs(
        self,
    ) -> partial[Generator[TritonConfig, None, None]]:
        filtered_configs = self._filter_configs(
            self.scaled_persistent_mm_configs, self.default_num_stages
        )
        return partial(self.preprocess_mm_configs, configs=filtered_configs)

    def get_mm_plus_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        # mm_plus_mm on ROCm is requested with a single pipeline stage; this is
        # now honored by _filter_configs instead of being silently overridden.
        filtered_configs = self._filter_configs(self.mm_plus_mm_configs, 1)
        return partial(self._finalize_mm_configs, configs=filtered_configs)

    def get_conv_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        filtered_configs = self._filter_configs(
            self.conv_configs, self.default_num_stages
        )
        return partial(self.preprocess_mm_configs, configs=filtered_configs)
class XPUConfigHeuristic(BaseConfigHeuristic):
    """XPU config heuristic; uses the base config tables and logic unchanged."""