[ROCm] Introduce AMD specific inductor gemm tuning (#147315)

Replaces https://github.com/pytorch/pytorch/pull/143286

Adds ROCm-specific MM configs for max-autotune, incorporating ROCm-specific Triton tuning kernel args (kernargs) such as `waves_per_eu`, `kpack`, and `matrix_instr_nonkdim`. This PR also introduces support for tuning `GROUP_M` in the Triton GEMM case.
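
As a rough illustration, the new ROCm GEMM config shape added in the heuristics diff below looks like the following. This is a flattened sketch (the PR splits it into `BaseConfig` → `GemmConfig` → `ROCmGemmConfig`); field defaults mirror the diff, while the example `num_stages` value is illustrative since the real configs use the backend default:

```python
from dataclasses import dataclass

# Flattened sketch of the config hierarchy added in this PR
# (BaseConfig -> GemmConfig -> ROCmGemmConfig); defaults mirror the diff.
@dataclass
class ROCmGemmConfigSketch:
    block_m: int
    block_n: int
    block_k: int
    num_stages: int
    num_warps: int
    group_m: int = 8                 # now tunable per config in this PR
    matrix_instr_nonkdim: int = 16   # AMD MFMA instruction non-K dimension kernarg
    waves_per_eu: int = 0            # 0 leaves the occupancy hint unset
    kpack: int = 2

# One candidate in the spirit of the new ROCm list:
# 128x128x32 tiles, 8 warps, GROUP_M=16, waves_per_eu=2.
example = ROCmGemmConfigSketch(128, 128, 32, 2, 8, group_m=16, waves_per_eu=2)
```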

Dynamo huggingface inference benchmarks:
`TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS="TRITON" python huggingface.py --performance --inference --bfloat16 --backend=inductor`

GEOMEAN speedup (before): 1.35x
GEOMEAN speedup (after): 1.42x

name | eager abs latency | old abs latency | old speedup | new abs latency | new speedup
-- | -- | -- | -- | -- | --
AlbertForMaskedLM | 26.22 | 26.52 | 98.86% | 24.58 | 106.67%
AlbertForQuestionAnswering | 25.96 | 26.40 | 98.33% | 24.10 | 107.73%
AllenaiLongformerBase | 21.03 | 10.65 | 197.50% | 10.49 | 200.58%
BartForCausalLM | 7.77 | 9.76 | 79.63% | 8.79 | 88.46%
BartForConditionalGeneration | 14.44 | 12.86 | 112.26% | 11.96 | 120.70%
BertForMaskedLM | 8.10 | 8.82 | 91.89% | 8.57 | 94.53%
BertForQuestionAnswering | 6.82 | 7.32 | 93.20% | 7.10 | 96.18%
BlenderbotForCausalLM | 10.97 | 11.39 | 96.34% | 10.10 | 108.65%
BlenderbotSmallForCausalLM | 5.91 | 5.44 | 108.72% | 4.82 | 122.67%
BlenderbotSmallForConditionalGeneration | 12.64 | 9.65 | 130.94% | 9.11 | 138.83%
CamemBert | 8.35 | 9.15 | 91.24% | 8.86 | 94.27%
DebertaForMaskedLM | 10.92 | 6.09 | 179.44% | 5.90 | 185.05%
DebertaForQuestionAnswering | 14.29 | 7.70 | 185.59% | 7.26 | 196.75%
DebertaV2ForMaskedLM | 15.47 | 10.22 | 151.32% | 9.34 | 165.55%
DebertaV2ForQuestionAnswering | 14.98 | 6.11 | 245.28% | 6.28 | 238.40%
DistilBertForMaskedLM | 8.37 | 8.70 | 96.30% | 8.22 | 101.92%
DistilBertForQuestionAnswering | 10.21 | 10.54 | 96.88% | 10.39 | 98.36%
DistillGPT2 | 8.77 | 6.78 | 129.40% | 6.31 | 138.88%
ElectraForCausalLM | 10.32 | 4.70 | 219.45% | 4.60 | 224.29%
ElectraForQuestionAnswering | 11.48 | 5.62 | 204.20% | 5.44 | 210.95%
GPT2ForSequenceClassification | 6.21 | 5.72 | 108.50% | 5.58 | 111.26%
GoogleFnet | 26.51 | 20.81 | 127.37% | 19.91 | 133.11%
LayoutLMForMaskedLM | 12.09 | 7.99 | 151.28% | 7.66 | 157.80%
LayoutLMForSequenceClassification | 10.62 | 6.49 | 163.67% | 6.25 | 169.95%
M2M100ForConditionalGeneration | 14.98 | 10.20 | 146.79% | 9.89 | 151.42%
MBartForCausalLM | 7.67 | 9.78 | 78.44% | 8.87 | 86.55%
MBartForConditionalGeneration | 13.45 | 12.69 | 105.99% | 12.03 | 111.82%
MT5ForConditionalGeneration | 19.96 | 5.32 | 375.37% | 5.08 | 393.01%
MegatronBertForCausalLM | 13.22 | 7.86 | 168.07% | 7.18 | 184.01%
MegatronBertForQuestionAnswering | 15.62 | 11.81 | 132.21% | 11.02 | 141.68%
MobileBertForMaskedLM | 26.63 | 10.82 | 245.99% | 11.95 | 222.73%
MobileBertForQuestionAnswering | 23.53 | 7.55 | 311.51% | 9.53 | 247.03%
OPTForCausalLM | 7.33 | 7.64 | 95.93% | 7.56 | 96.90%
PLBartForCausalLM | 8.73 | 7.63 | 114.40% | 7.37 | 118.58%
PLBartForConditionalGeneration | 10.46 | 8.50 | 122.98% | 8.16 | 128.13%
PegasusForCausalLM | 7.18 | 7.37 | 97.42% | 6.64 | 108.22%
PegasusForConditionalGeneration | 16.47 | 16.66 | 98.87% | 14.18 | 116.13%
RobertaForCausalLM | 10.30 | 9.95 | 103.52% | 9.52 | 108.25%
RobertaForQuestionAnswering | 6.37 | 7.13 | 89.28% | 6.79 | 93.87%
T5ForConditionalGeneration | 12.40 | 6.72 | 184.51% | 6.48 | 191.16%
T5Small | 12.02 | 6.66 | 180.55% | 6.32 | 190.33%
TrOCRForCausalLM | 14.12 | 13.31 | 106.11% | 12.45 | 113.41%
XGLMForCausalLM | 16.48 | 6.23 | 264.52% | 6.35 | 259.51%
XLNetLMHeadModel | 74.87 | 62.23 | 120.32% | 57.95 | 129.19%
YituTechConvBert | 20.21 | 10.50 | 192.48% | 9.97 | 202.72%
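
For reference, the geomean figures above are the geometric mean of the per-model speedup column; a minimal sketch of how to reproduce them (speedups expressed as ratios, e.g. 106.67% → 1.0667):

```python
import math

def geomean(speedups: list[float]) -> float:
    """Geometric mean of per-model speedups given as ratios (1.0667 for 106.67%)."""
    return math.exp(sum(math.log(s) for s in speedups) / len(speedups))

# e.g. geomean over all rows of the "new speedup" column should land near
# the 1.42x reported above.
```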

We are also seeing a ~9% improvement on an internal addmm benchmark.

This PR also slightly reduces compilation time for max-autotune on AMD: before this change every config was assessed with both `matrix_instr_nonkdim` values [0, 16], whereas with this update we use 16 for all configs.
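
A minimal sketch of the pruning effect (the `expand` helper and tile list below are illustrative, not the real `_finalize_mm_configs` code): previously each tile shape was expanded for both `matrix_instr_nonkdim` values, now a single value is used.

```python
def expand(tile_shapes, nonkdim_choices):
    """Enumerate tuning candidates for a set of (block_m, block_n) tiles."""
    out = []
    for block_m, block_n in tile_shapes:
        for mi in nonkdim_choices:
            # block_m and block_n must be multiples of matrix_instr_nonkdim
            if mi != 0 and (block_m % mi or block_n % mi):
                continue
            out.append((block_m, block_n, mi))
    return out

tiles = [(32, 32), (64, 128), (128, 256)]
before = expand(tiles, [0, 16])  # 6 candidates: every tile tried with 0 and 16
after = expand(tiles, [16])      # 3 candidates: one value per tile
```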

There is currently no CI to test max-autotune performance, but this will be enabled via https://github.com/pytorch/pytorch/pull/148672, after which we can investigate further tuning updates and config pruning.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147315
Approved by: https://github.com/jansel, https://github.com/eellison
Jack Taylor, 2025-04-09 14:34:30 +00:00, committed by PyTorch MergeBot
parent 886d9acb0d
commit 2299087220
3 changed files with 416 additions and 278 deletions


@@ -72,8 +72,7 @@ def mm_options(config, sym_m, sym_n, sym_k, layout):
not inductor_config.force_same_precision
or ((sym_m % 16) == 0 and (sym_n % 16) == 0 and (sym_k % 8) == 0)
)
return dict(
GROUP_M=8,
options_dict = dict(
EVEN_K=even_k_symbolic,
ALLOW_TF32=allow_tf32,
USE_FAST_ACCUM=False, # Option for _scaled_mm
@@ -83,6 +82,13 @@ def mm_options(config, sym_m, sym_n, sym_k, layout):
**config.kwargs,
)
# If GROUP_M not specified then default to 8
if "GROUP_M" not in config.kwargs:
group_m = config.kwargs.get("GROUP_M", 8)
options_dict["GROUP_M"] = group_m
return options_dict
def persistent_mm_options(mat1, mat2):
return dict(
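
A standalone illustration of the `GROUP_M` defaulting behaviour in the hunk above, assuming `config.kwargs` may or may not carry `GROUP_M` (the `build_options` helper and its placeholder entries are hypothetical, not the real `mm_options` signature):

```python
def build_options(config_kwargs: dict) -> dict:
    # Placeholder stand-ins for EVEN_K / ALLOW_TF32 / ACC_TYPE etc.
    options = dict(EVEN_K=True, ALLOW_TF32=True, **config_kwargs)
    # If GROUP_M was not specified by the config, fall back to the old default of 8
    if "GROUP_M" not in config_kwargs:
        options["GROUP_M"] = 8
    return options

assert build_options({})["GROUP_M"] == 8                # legacy configs keep GROUP_M=8
assert build_options({"GROUP_M": 16})["GROUP_M"] == 16  # new ROCm configs can tune it
```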


@@ -1282,6 +1282,7 @@ class TritonTemplate(KernelTemplate):
),
"num_stages": num_stages,
"num_warps": num_warps,
"GROUP_M": kwargs.get("GROUP_M", -1),
"allow_tf32": str(kwargs.get("ALLOW_TF32", None)),
"acc_type": str(kwargs.get("ACC_TYPE", None)),
},


@@ -1,7 +1,7 @@
from __future__ import annotations
import dataclasses
import itertools
from collections import namedtuple
from functools import partial
from threading import Lock
from typing import Any, Callable, TYPE_CHECKING
@@ -14,12 +14,59 @@ from .virtualized import V
if TYPE_CHECKING:
from collections.abc import Generator, Sequence
from collections.abc import Generator
from triton import Config as TritonConfig
class BaseConfigSingleton(type):
@dataclasses.dataclass
class BaseConfig:
"""
Base Gemm configuration used for most backends (CPU, CUDA)
"""
block_m: int
block_n: int
block_k: int
num_stages: int
num_warps: int
@dataclasses.dataclass
class GemmConfig(BaseConfig):
"""
Gemm configuration used for most backends (CPU, CUDA)
"""
group_m: int = 8
ConvConfig = BaseConfig
@dataclasses.dataclass
class ROCmGemmConfig(GemmConfig):
"""
ROCm subclass for GEMMs, with AMD backend specific tuneable kernargs
"""
matrix_instr_nonkdim: int = 16
waves_per_eu: int = 0
kpack: int = 2
@dataclasses.dataclass
class ROCmConvConfig(ConvConfig):
"""
ROCm subclass for Conv, with AMD backend specific tuneable kernargs
"""
matrix_instr_nonkdim: int = 16
waves_per_eu: int = 0
kpack: int = 2
class BaseHeuristicSingleton(type):
"""
Thread-safe implementation of single to be used in the config heuristic subclasses
to ensure heavy __init__ calls are not repeatedly run
@@ -29,7 +76,7 @@ class BaseConfigSingleton(type):
_lock: Lock = Lock()
def __call__(
cls: BaseConfigSingleton, *args: Any, **kwargs: Any
cls: BaseHeuristicSingleton, *args: Any, **kwargs: Any
) -> BaseConfigHeuristic:
with cls._lock:
if cls not in cls._instances:
@@ -38,12 +85,7 @@ class BaseConfigSingleton(type):
return cls._instances[cls]
Config = namedtuple(
"Config", ["block_m", "block_n", "block_k", "num_stages", "num_warps"]
)
class BaseConfigHeuristic(metaclass=BaseConfigSingleton):
class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
"""
Base class for mm_configs, device specific triton kernels config inherit from here
"""
@@ -52,36 +94,37 @@ class BaseConfigHeuristic(metaclass=BaseConfigSingleton):
# List of dictionaries to store the kernel configs. Configs that evaluate to true
# will be utilised on the target platform. The configs are as follows:
# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
self.mm_configs = [
Config(32, 32, 16, 1, 2),
Config(32, 32, 128, 2, 4),
Config(32, 64, 32, 5, 8),
Config(64, 32, 32, 5, 8),
Config(64, 32, 128, 5, 4),
Config(64, 64, 16, 2, 4),
Config(64, 64, 32, 2, 4),
Config(64, 64, 64, 3, 8),
Config(64, 64, 128, 5, 4),
Config(64, 128, 32, 3, 4),
Config(64, 128, 32, 4, 8),
Config(64, 128, 64, 3, 4),
Config(64, 128, 128, 4, 4),
Config(128, 64, 32, 3, 4),
Config(128, 64, 32, 4, 8),
Config(128, 128, 32, 2, 8),
Config(128, 128, 32, 3, 4),
Config(128, 128, 64, 3, 4),
Config(128, 128, 64, 5, 8),
self.mm_configs: list[BaseConfig] = [
GemmConfig(32, 32, 16, 1, 2),
GemmConfig(32, 32, 128, 2, 4),
GemmConfig(32, 64, 32, 5, 8),
GemmConfig(64, 32, 32, 5, 8),
GemmConfig(64, 32, 128, 5, 4),
GemmConfig(64, 64, 16, 2, 4),
GemmConfig(64, 64, 32, 2, 4),
GemmConfig(64, 64, 64, 3, 8),
GemmConfig(64, 64, 128, 5, 4),
GemmConfig(64, 128, 32, 3, 4),
GemmConfig(64, 128, 32, 4, 8),
GemmConfig(64, 128, 64, 3, 4),
GemmConfig(64, 128, 128, 4, 4),
GemmConfig(128, 64, 32, 3, 4),
GemmConfig(128, 64, 32, 4, 8),
GemmConfig(128, 128, 32, 2, 8),
GemmConfig(128, 128, 32, 3, 4),
GemmConfig(128, 128, 64, 3, 4),
GemmConfig(128, 128, 64, 5, 8),
]
# Exhaustive search for mm configs
self.exhaustive_configs = [
Config(BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
self.exhaustive_configs: list[BaseConfig] = [
GemmConfig(BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps, group_m)
for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product(
[16, 32, 64, 128, 256], repeat=3
)
for num_stages in [1, 2, 3, 4, 5]
for num_warps in [2, 4, 8]
for group_m in [8]
]
# these are only used in tuned_mm when AutoHeuristic is enabled
@@ -89,220 +132,237 @@ class BaseConfigHeuristic(metaclass=BaseConfigSingleton):
# when the learned heuristic is used, the learned heuristic reduces the number of configs down to 10
# which saves compilation time (since less configs are autotuned) and potentially increase performance
# because the learned heuristic might predict a config that is not part mm_configs
self.extra_mm_configs = [
Config(16, 32, 16, 3, 2),
Config(16, 32, 32, 4, 2),
Config(16, 32, 32, 5, 2),
Config(64, 64, 128, 3, 4),
Config(128, 64, 32, 2, 2),
Config(128, 64, 64, 3, 8),
Config(128, 64, 128, 4, 8),
Config(128, 128, 32, 4, 4),
Config(128, 128, 64, 3, 8),
Config(128, 128, 64, 5, 4),
self.extra_mm_configs: list[BaseConfig] = [
GemmConfig(16, 32, 16, 3, 2),
GemmConfig(16, 32, 32, 4, 2),
GemmConfig(16, 32, 32, 5, 2),
GemmConfig(64, 64, 128, 3, 4),
GemmConfig(128, 64, 32, 2, 2),
GemmConfig(128, 64, 64, 3, 8),
GemmConfig(128, 64, 128, 4, 8),
GemmConfig(128, 128, 32, 4, 4),
GemmConfig(128, 128, 64, 3, 8),
GemmConfig(128, 128, 64, 5, 4),
]
self.int8_mm_configs = [
Config(64, 64, 32, 2, 4),
Config(64, 128, 32, 3, 4),
Config(128, 64, 32, 3, 4),
Config(64, 128, 32, 4, 8),
Config(128, 64, 32, 4, 8),
Config(64, 32, 32, 5, 8),
Config(32, 64, 32, 5, 8),
Config(128, 128, 32, 2, 8),
Config(64, 64, 64, 3, 8),
Config(128, 256, 128, 3, 8),
Config(256, 128, 128, 3, 8),
self.int8_mm_configs: list[BaseConfig] = [
GemmConfig(64, 64, 32, 2, 4),
GemmConfig(64, 128, 32, 3, 4),
GemmConfig(128, 64, 32, 3, 4),
GemmConfig(64, 128, 32, 4, 8),
GemmConfig(128, 64, 32, 4, 8),
GemmConfig(64, 32, 32, 5, 8),
GemmConfig(32, 64, 32, 5, 8),
GemmConfig(128, 128, 32, 2, 8),
GemmConfig(64, 64, 64, 3, 8),
GemmConfig(128, 256, 128, 3, 8),
GemmConfig(256, 128, 128, 3, 8),
]
self.mixed_mm_configs = [
Config(16, 128, 256, 3, 4),
Config(16, 128, 256, 5, 8),
self.mixed_mm_configs: list[BaseConfig] = [
GemmConfig(16, 128, 256, 3, 4),
GemmConfig(16, 128, 256, 5, 8),
]
self.persistent_mm_configs = [
Config(128, 256, 64, 3, 8),
Config(128, 128, 64, 3, 8),
Config(128, 128, 128, 3, 8),
Config(128, 128, 128, 3, 4),
Config(128, 128, 64, 4, 8),
self.persistent_mm_configs: list[BaseConfig] = [
GemmConfig(128, 256, 64, 3, 8),
GemmConfig(128, 128, 64, 3, 8),
GemmConfig(128, 128, 128, 3, 8),
GemmConfig(128, 128, 128, 3, 4),
GemmConfig(128, 128, 64, 4, 8),
]
self.scaled_mm_configs = [
Config(128, 256, 32, 3, 8),
Config(256, 128, 32, 3, 8),
Config(256, 64, 32, 4, 4),
Config(64, 256, 32, 4, 4),
Config(128, 128, 32, 4, 4),
Config(128, 64, 32, 4, 4),
Config(64, 128, 32, 4, 4),
Config(128, 32, 32, 4, 4),
Config(64, 32, 32, 5, 2),
Config(256, 128, 128, 3, 8),
Config(256, 64, 128, 4, 4),
Config(64, 256, 128, 4, 4),
Config(128, 128, 128, 4, 4),
Config(128, 64, 64, 4, 4),
Config(64, 128, 64, 4, 4),
Config(128, 32, 64, 4, 4),
Config(64, 32, 64, 5, 2),
Config(16, 32, 32, 2, 2),
Config(16, 64, 32, 2, 2),
Config(16, 128, 32, 2, 4),
Config(16, 256, 32, 2, 4),
Config(16, 32, 64, 2, 2),
Config(16, 64, 64, 2, 2),
Config(16, 128, 64, 2, 4),
Config(16, 256, 64, 2, 4),
Config(32, 32, 32, 2, 2),
Config(32, 64, 32, 2, 2),
Config(32, 128, 32, 2, 4),
Config(32, 256, 32, 2, 4),
Config(32, 32, 64, 2, 2),
Config(32, 64, 64, 2, 2),
Config(32, 128, 64, 2, 4),
Config(32, 256, 64, 2, 4),
Config(16, 32, 32, 3, 2),
Config(16, 64, 32, 3, 2),
Config(16, 128, 32, 3, 4),
Config(16, 256, 32, 3, 4),
Config(16, 32, 64, 3, 2),
Config(16, 64, 64, 3, 2),
Config(16, 128, 64, 3, 4),
Config(16, 256, 64, 3, 4),
Config(32, 32, 32, 3, 2),
Config(32, 64, 32, 3, 2),
Config(32, 128, 32, 3, 4),
Config(32, 256, 32, 3, 4),
Config(32, 32, 64, 3, 2),
Config(32, 64, 64, 3, 2),
Config(32, 128, 64, 3, 4),
Config(32, 256, 64, 3, 4),
Config(16, 32, 32, 4, 2),
Config(16, 64, 32, 4, 2),
Config(16, 128, 32, 4, 4),
Config(16, 256, 32, 4, 4),
Config(16, 32, 64, 4, 2),
Config(16, 64, 64, 4, 2),
Config(16, 128, 64, 4, 4),
Config(16, 256, 64, 4, 4),
Config(32, 32, 32, 4, 2),
Config(32, 64, 32, 4, 2),
Config(32, 128, 32, 4, 4),
Config(32, 256, 32, 4, 4),
Config(32, 32, 64, 4, 2),
Config(32, 64, 64, 4, 2),
Config(32, 128, 64, 4, 4),
Config(32, 256, 64, 4, 4),
Config(16, 32, 32, 5, 2),
Config(16, 64, 32, 5, 2),
Config(16, 128, 32, 5, 4),
Config(16, 256, 32, 5, 4),
Config(16, 32, 64, 5, 2),
Config(16, 64, 64, 5, 2),
Config(16, 128, 64, 5, 4),
Config(16, 256, 64, 5, 4),
Config(32, 32, 32, 5, 2),
Config(32, 64, 32, 5, 2),
Config(32, 128, 32, 5, 4),
Config(32, 256, 32, 5, 4),
Config(32, 32, 64, 5, 2),
Config(32, 64, 64, 5, 2),
Config(32, 128, 64, 5, 4),
Config(32, 256, 64, 5, 4),
Config(16, 32, 32, 6, 2),
Config(16, 64, 32, 6, 2),
Config(16, 128, 32, 6, 4),
Config(16, 256, 32, 6, 4),
Config(16, 32, 64, 6, 2),
Config(16, 64, 64, 6, 2),
Config(16, 128, 64, 6, 4),
Config(16, 256, 64, 6, 4),
Config(32, 32, 32, 6, 2),
Config(32, 64, 32, 6, 2),
Config(32, 128, 32, 6, 4),
Config(32, 256, 32, 6, 4),
Config(32, 32, 64, 6, 2),
Config(32, 64, 64, 6, 2),
Config(32, 128, 64, 6, 4),
Config(32, 256, 64, 6, 4),
self.scaled_mm_configs: list[BaseConfig] = [
GemmConfig(128, 256, 32, 3, 8),
GemmConfig(256, 128, 32, 3, 8),
GemmConfig(256, 64, 32, 4, 4),
GemmConfig(64, 256, 32, 4, 4),
GemmConfig(128, 128, 32, 4, 4),
GemmConfig(128, 64, 32, 4, 4),
GemmConfig(64, 128, 32, 4, 4),
GemmConfig(128, 32, 32, 4, 4),
GemmConfig(64, 32, 32, 5, 2),
GemmConfig(256, 128, 128, 3, 8),
GemmConfig(256, 64, 128, 4, 4),
GemmConfig(64, 256, 128, 4, 4),
GemmConfig(128, 128, 128, 4, 4),
GemmConfig(128, 64, 64, 4, 4),
GemmConfig(64, 128, 64, 4, 4),
GemmConfig(128, 32, 64, 4, 4),
GemmConfig(64, 32, 64, 5, 2),
GemmConfig(16, 32, 32, 2, 2),
GemmConfig(16, 64, 32, 2, 2),
GemmConfig(16, 128, 32, 2, 4),
GemmConfig(16, 256, 32, 2, 4),
GemmConfig(16, 32, 64, 2, 2),
GemmConfig(16, 64, 64, 2, 2),
GemmConfig(16, 128, 64, 2, 4),
GemmConfig(16, 256, 64, 2, 4),
GemmConfig(32, 32, 32, 2, 2),
GemmConfig(32, 64, 32, 2, 2),
GemmConfig(32, 128, 32, 2, 4),
GemmConfig(32, 256, 32, 2, 4),
GemmConfig(32, 32, 64, 2, 2),
GemmConfig(32, 64, 64, 2, 2),
GemmConfig(32, 128, 64, 2, 4),
GemmConfig(32, 256, 64, 2, 4),
GemmConfig(16, 32, 32, 3, 2),
GemmConfig(16, 64, 32, 3, 2),
GemmConfig(16, 128, 32, 3, 4),
GemmConfig(16, 256, 32, 3, 4),
GemmConfig(16, 32, 64, 3, 2),
GemmConfig(16, 64, 64, 3, 2),
GemmConfig(16, 128, 64, 3, 4),
GemmConfig(16, 256, 64, 3, 4),
GemmConfig(32, 32, 32, 3, 2),
GemmConfig(32, 64, 32, 3, 2),
GemmConfig(32, 128, 32, 3, 4),
GemmConfig(32, 256, 32, 3, 4),
GemmConfig(32, 32, 64, 3, 2),
GemmConfig(32, 64, 64, 3, 2),
GemmConfig(32, 128, 64, 3, 4),
GemmConfig(32, 256, 64, 3, 4),
GemmConfig(16, 32, 32, 4, 2),
GemmConfig(16, 64, 32, 4, 2),
GemmConfig(16, 128, 32, 4, 4),
GemmConfig(16, 256, 32, 4, 4),
GemmConfig(16, 32, 64, 4, 2),
GemmConfig(16, 64, 64, 4, 2),
GemmConfig(16, 128, 64, 4, 4),
GemmConfig(16, 256, 64, 4, 4),
GemmConfig(32, 32, 32, 4, 2),
GemmConfig(32, 64, 32, 4, 2),
GemmConfig(32, 128, 32, 4, 4),
GemmConfig(32, 256, 32, 4, 4),
GemmConfig(32, 32, 64, 4, 2),
GemmConfig(32, 64, 64, 4, 2),
GemmConfig(32, 128, 64, 4, 4),
GemmConfig(32, 256, 64, 4, 4),
GemmConfig(16, 32, 32, 5, 2),
GemmConfig(16, 64, 32, 5, 2),
GemmConfig(16, 128, 32, 5, 4),
GemmConfig(16, 256, 32, 5, 4),
GemmConfig(16, 32, 64, 5, 2),
GemmConfig(16, 64, 64, 5, 2),
GemmConfig(16, 128, 64, 5, 4),
GemmConfig(16, 256, 64, 5, 4),
GemmConfig(32, 32, 32, 5, 2),
GemmConfig(32, 64, 32, 5, 2),
GemmConfig(32, 128, 32, 5, 4),
GemmConfig(32, 256, 32, 5, 4),
GemmConfig(32, 32, 64, 5, 2),
GemmConfig(32, 64, 64, 5, 2),
GemmConfig(32, 128, 64, 5, 4),
GemmConfig(32, 256, 64, 5, 4),
GemmConfig(16, 32, 32, 6, 2),
GemmConfig(16, 64, 32, 6, 2),
GemmConfig(16, 128, 32, 6, 4),
GemmConfig(16, 256, 32, 6, 4),
GemmConfig(16, 32, 64, 6, 2),
GemmConfig(16, 64, 64, 6, 2),
GemmConfig(16, 128, 64, 6, 4),
GemmConfig(16, 256, 64, 6, 4),
GemmConfig(32, 32, 32, 6, 2),
GemmConfig(32, 64, 32, 6, 2),
GemmConfig(32, 128, 32, 6, 4),
GemmConfig(32, 256, 32, 6, 4),
GemmConfig(32, 32, 64, 6, 2),
GemmConfig(32, 64, 64, 6, 2),
GemmConfig(32, 128, 64, 6, 4),
GemmConfig(32, 256, 64, 6, 4),
]
self.scaled_persistent_mm_configs = [
Config(128, 128, 64, 3, 8),
Config(128, 128, 128, 3, 8),
Config(128, 128, 128, 4, 8),
Config(128, 128, 128, 4, 4),
Config(128, 128, 128, 3, 4),
Config(128, 128, 128, 5, 4),
Config(128, 128, 128, 5, 8),
Config(128, 128, 128, 6, 8),
Config(128, 128, 64, 4, 8),
self.scaled_persistent_mm_configs: list[BaseConfig] = [
GemmConfig(128, 128, 64, 3, 8),
GemmConfig(128, 128, 128, 3, 8),
GemmConfig(128, 128, 128, 4, 8),
GemmConfig(128, 128, 128, 4, 4),
GemmConfig(128, 128, 128, 3, 4),
GemmConfig(128, 128, 128, 5, 4),
GemmConfig(128, 128, 128, 5, 8),
GemmConfig(128, 128, 128, 6, 8),
GemmConfig(128, 128, 64, 4, 8),
]
# TODO: Unify with other gemm patterns, mm_plus_mm currently follows
# slightly different pattern than rest
self.mm_plus_mm_configs = [
Config(64, 64, 32, 2, 4),
Config(64, 64, 32, 3, 8),
Config(64, 64, 32, 4, 16),
Config(64, 32, 32, 4, 8),
Config(32, 64, 32, 4, 8),
Config(128, 128, 32, 1, 8),
Config(64, 64, 64, 1, 8),
Config(32, 32, 128, 1, 8),
Config(64, 64, 16, 2, 4),
Config(32, 32, 16, 1, 2),
self.mm_plus_mm_configs: list[BaseConfig] = [
GemmConfig(64, 64, 32, 2, 4),
GemmConfig(64, 64, 32, 3, 8),
GemmConfig(64, 64, 32, 4, 16),
GemmConfig(64, 32, 32, 4, 8),
GemmConfig(32, 64, 32, 4, 8),
GemmConfig(128, 128, 32, 1, 8),
GemmConfig(64, 64, 64, 1, 8),
GemmConfig(32, 32, 128, 1, 8),
GemmConfig(64, 64, 16, 2, 4),
GemmConfig(32, 32, 16, 1, 2),
]
self.conv_configs = [
Config(64, 256, 16, 2, 4),
Config(256, 64, 16, 2, 4),
Config(1024, 16, 16, 1, 8),
Config(128, 128, 32, 2, 8),
Config(64, 64, 32, 2, 4),
Config(64, 256, 32, 2, 8),
Config(256, 64, 32, 2, 8),
self.conv_configs: list[BaseConfig] = [
ConvConfig(64, 256, 16, 2, 4),
ConvConfig(256, 64, 16, 2, 4),
ConvConfig(1024, 16, 16, 1, 8),
ConvConfig(128, 128, 32, 2, 8),
ConvConfig(64, 64, 32, 2, 4),
ConvConfig(64, 256, 32, 2, 8),
ConvConfig(256, 64, 32, 2, 8),
]
def _finalize_mm_configs(
self,
configs: list[Config],
configs: list[BaseConfig],
) -> Generator[TritonConfig, None, None]:
"""
Finalizes configs after scaling, applying additional constraints.
"""
used = OrderedSet[Config]()
used: OrderedSet[tuple[int, ...]] = OrderedSet()
max_mm_configs = config.test_configs.max_mm_configs
for block_m, block_n, block_k, num_stages, num_warps in configs:
for conf in configs:
# Each warp computes a 16x16 tile = 256 elements
num_warps = min(num_warps, block_m * block_n // 256)
num_warps = min(conf.num_warps, conf.block_m * conf.block_n // 256)
if (
Config(block_m, block_n, block_k, num_stages, num_warps)
) not in used and (max_mm_configs is None or len(used) < max_mm_configs):
used.add(Config(block_m, block_n, block_k, num_stages, num_warps))
yield self.triton_config(
BLOCK_M=block_m,
BLOCK_N=block_n,
BLOCK_K=block_k,
num_stages=num_stages,
num_warps=num_warps,
)
# Construct key for finding duplicate configs
key: tuple[int, ...] = (
conf.block_m,
conf.block_n,
conf.block_k,
conf.num_stages,
num_warps,
)
# Check if gemm specific arg exists - add to key if does
group_m = getattr(conf, "group_m", None)
if group_m is not None:
key += (group_m,)
if key not in used and (
max_mm_configs is None or len(used) < max_mm_configs
):
used.add(key)
kwargs = {
"BLOCK_M": conf.block_m,
"BLOCK_N": conf.block_n,
"BLOCK_K": conf.block_k,
"num_stages": conf.num_stages,
"num_warps": num_warps,
}
if group_m is not None:
kwargs["GROUP_M"] = group_m
yield self.triton_config(**kwargs)
def _scale_mm_configs(
self,
m: int,
n: int,
k: int,
configs: Sequence[Config],
configs: list[BaseConfig],
scale: float,
has_int8_tensor: bool,
exclude: Callable[[int, int, int], bool],
) -> list[Config]:
) -> list[BaseConfig]:
"""
Scales and filters matrix multiplication configs based on input size.
"""
@@ -341,7 +401,8 @@ class BaseConfigHeuristic(metaclass=BaseConfigSingleton):
scaled_configs = []
for c in configs:
scaled_config = c._replace(
scaled_config = dataclasses.replace(
c,
block_m=max(min(int(c.block_m * scale), m), min_block_size),
block_n=max(min(int(c.block_n * scale), n), min_block_size),
block_k=max(min(int(c.block_k * scale), k), min_block_size_k),
@@ -359,7 +420,7 @@ class BaseConfigHeuristic(metaclass=BaseConfigSingleton):
m: int,
n: int,
k: int,
configs: Sequence[Config],
configs: list[BaseConfig],
has_int8_tensor: bool = False,
scale: int = 1,
exclude: Callable[[int, int, int], bool] = lambda m, n, k: False,
@@ -430,90 +491,160 @@ class ROCmConfigHeuristic(BaseConfigHeuristic):
self.default_num_stages = get_backend_num_stages()
self.mm_configs: list[BaseConfig] = [
ROCmGemmConfig(
16, 16, 256, self.default_num_stages, 4, group_m=4, waves_per_eu=2
),
ROCmGemmConfig(32, 16, 256, self.default_num_stages, 4, group_m=4),
ROCmGemmConfig(
32, 32, 16, self.default_num_stages, 4, group_m=8, waves_per_eu=2
),
ROCmGemmConfig(32, 32, 128, self.default_num_stages, 4, group_m=8),
ROCmGemmConfig(32, 64, 64, self.default_num_stages, 4, group_m=8),
ROCmGemmConfig(
64, 16, 128, self.default_num_stages, 4, group_m=8, waves_per_eu=2
),
ROCmGemmConfig(64, 32, 32, self.default_num_stages, 4, group_m=8),
ROCmGemmConfig(64, 32, 64, self.default_num_stages, 4, group_m=8),
ROCmGemmConfig(64, 32, 64, self.default_num_stages, 8, group_m=8),
ROCmGemmConfig(64, 32, 128, self.default_num_stages, 4, group_m=8),
ROCmGemmConfig(64, 64, 16, self.default_num_stages, 4, group_m=8),
ROCmGemmConfig(64, 64, 64, self.default_num_stages, 4, group_m=4),
ROCmGemmConfig(64, 64, 128, self.default_num_stages, 8, group_m=16),
ROCmGemmConfig(64, 64, 256, self.default_num_stages, 8, group_m=4),
ROCmGemmConfig(
64, 128, 32, self.default_num_stages, 4, group_m=4, waves_per_eu=2
),
ROCmGemmConfig(64, 128, 32, self.default_num_stages, 8, group_m=8),
ROCmGemmConfig(64, 128, 64, self.default_num_stages, 8, group_m=4),
ROCmGemmConfig(64, 128, 128, self.default_num_stages, 8, group_m=4),
ROCmGemmConfig(128, 32, 32, self.default_num_stages, 4, group_m=8),
ROCmGemmConfig(128, 32, 64, self.default_num_stages, 4, group_m=8),
ROCmGemmConfig(
128, 64, 32, self.default_num_stages, 4, group_m=8, waves_per_eu=2
),
ROCmGemmConfig(128, 64, 64, self.default_num_stages, 4, group_m=16),
ROCmGemmConfig(128, 64, 128, self.default_num_stages, 8, group_m=4),
ROCmGemmConfig(
128, 128, 32, self.default_num_stages, 4, group_m=16, waves_per_eu=2
),
ROCmGemmConfig(128, 128, 32, self.default_num_stages, 8, group_m=16),
ROCmGemmConfig(
128, 128, 32, self.default_num_stages, 8, group_m=16, waves_per_eu=2
),
ROCmGemmConfig(128, 128, 64, self.default_num_stages, 4, group_m=16),
ROCmGemmConfig(128, 128, 64, self.default_num_stages, 8, group_m=8),
ROCmGemmConfig(128, 128, 128, self.default_num_stages, 8, group_m=16),
ROCmGemmConfig(
128, 256, 32, self.default_num_stages, 4, group_m=16, waves_per_eu=2
),
ROCmGemmConfig(128, 256, 64, self.default_num_stages, 8, group_m=4),
ROCmGemmConfig(256, 64, 64, self.default_num_stages, 8, group_m=4),
ROCmGemmConfig(
256, 128, 32, self.default_num_stages, 4, group_m=4, waves_per_eu=2
),
ROCmGemmConfig(256, 128, 32, self.default_num_stages, 8, group_m=16),
ROCmGemmConfig(256, 128, 64, self.default_num_stages, 8, group_m=4),
ROCmGemmConfig(256, 256, 64, self.default_num_stages, 8, group_m=4),
]
# Exhaustive search for mm configs
self.exhaustive_configs = [
Config(BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
self.exhaustive_configs: list[BaseConfig] = [
ROCmGemmConfig(
BLOCK_M,
BLOCK_N,
BLOCK_K,
num_stages,
num_warps,
group_m,
matrix_instr_nonkdim,
waves_per_eu,
kpack,
)
for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product(
[16, 32, 64, 128, 256], repeat=3
)
for num_stages in [1, self.default_num_stages]
for num_warps in [4, 8]
for group_m in [4, 8, 16]
for matrix_instr_nonkdim in [0, 16]
for waves_per_eu in [0, 2]
for kpack in [2]
]
def _filter_configs(
self, configs: list[Config], new_num_stages: int
) -> list[Config]:
filtered_configs = [
c._replace(num_stages=self.default_num_stages) for c in configs
]
return filtered_configs
self, configs: list[BaseConfig], new_num_stages: int
) -> list[BaseConfig]:
# TODO: _filter_configs can be removed once backend specific configs are added
# for all methods
for c in configs:
c.num_stages = self.default_num_stages
return configs
def _finalize_mm_configs(
self,
configs: list[Config],
configs: list[BaseConfig],
) -> Generator[TritonConfig, None, None]:
used = OrderedSet[tuple[Config, int, int]]()
"""
Finalizes configs after scaling, applying additional constraints.
"""
used: OrderedSet[tuple[int, ...]] = OrderedSet()
max_mm_configs = config.test_configs.max_mm_configs
for block_m, block_n, block_k, num_stages, num_warps in configs:
# each warp computes 16x16 tile = 256
num_warps = min(num_warps, block_m * block_n // 256)
kpack = 2
for matrix_instr_nonkdim in [0, 16]:
if matrix_instr_nonkdim != 0 and (
block_m % matrix_instr_nonkdim != 0
or block_n % matrix_instr_nonkdim != 0
):
# block_m and block_n must be a multiple of matrix_instr_nonkdim
continue
if (
Config(
block_m,
block_n,
block_k,
num_stages,
num_warps,
),
matrix_instr_nonkdim,
kpack,
) not in used and (
max_mm_configs is None or len(used) < max_mm_configs
):
used.add(
(
Config(
block_m,
block_n,
block_k,
num_stages,
num_warps,
),
matrix_instr_nonkdim,
kpack,
)
)
yield self.triton_config(
BLOCK_M=block_m,
BLOCK_N=block_n,
BLOCK_K=block_k,
num_stages=num_stages,
num_warps=num_warps,
matrix_instr_nonkdim=matrix_instr_nonkdim,
kpack=kpack,
)
for conf in configs:
# Each warp computes a 16x16 tile = 256 elements
conf.num_warps = min(conf.num_warps, conf.block_m * conf.block_n // 256)
def get_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
filtered_configs = self._filter_configs(
self.mm_configs, self.default_num_stages
)
return partial(self.preprocess_mm_configs, configs=filtered_configs)
# Defaults for AMD triton backend kern args if not set
matrix_instr_nonkdim = getattr(conf, "matrix_instr_nonkdim", 16)
waves_per_eu = getattr(conf, "waves_per_eu", 0)
kpack = getattr(conf, "kpack", 2)
def get_exhaustive_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
filtered_configs = self._filter_configs(
self.exhaustive_configs, self.default_num_stages
)
return partial(self.preprocess_mm_configs, configs=filtered_configs)
if matrix_instr_nonkdim != 0 and (
conf.block_m % matrix_instr_nonkdim != 0
or conf.block_n % matrix_instr_nonkdim != 0
):
# block_m and block_n must be a multiple of matrix_instr_nonkdim
continue
# Construct key for finding duplicate configs
key: tuple[int, ...] = (
conf.block_m,
conf.block_n,
conf.block_k,
conf.num_stages,
conf.num_warps,
waves_per_eu,
matrix_instr_nonkdim,
kpack,
)
# Check if gemm specific arg exists - add to key if does
group_m = getattr(conf, "group_m", None)
if group_m is not None:
key += (group_m,)
if waves_per_eu != 0:
waves_per_eu = int(8 // conf.num_warps)
if key not in used and (
max_mm_configs is None or len(used) < max_mm_configs
):
used.add(key)
kwargs = {
"BLOCK_M": conf.block_m,
"BLOCK_N": conf.block_n,
"BLOCK_K": conf.block_k,
"num_stages": conf.num_stages,
"num_warps": conf.num_warps,
"matrix_instr_nonkdim": matrix_instr_nonkdim,
"waves_per_eu": waves_per_eu,
"kpack": kpack,
}
if group_m is not None:
kwargs["GROUP_M"] = group_m
yield self.triton_config(**kwargs)
def get_extra_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
filtered_configs = self._filter_configs(