[ROCm] Introduce AMD specific inductor gemm tuning (#147315)
Replaces https://github.com/pytorch/pytorch/pull/143286.

Adds ROCm-specific MM configs for max-autotune, incorporating ROCm-specific Triton tuning kernargs such as waves_per_eu, kpack, and matrix_instr_nonkdim. This PR also introduces behavior to allow tuning of GROUP_M in the Triton GEMM case.

Dynamo huggingface inference benchmarks:

```
TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS="TRITON" python huggingface.py --performance --inference --bfloat16 --backend=inductor
```

GEOMEAN speedup (before): 1.35x
GEOMEAN speedup (after): 1.42x

| name | Eager - abs latency | old - abs_latency | old - speedup | new - abs_latency | new - speedup |
| -- | -- | -- | -- | -- | -- |
| AlbertForMaskedLM | 26.22 | 26.52 | 98.86% | 24.58 | 106.67% |
| AlbertForQuestionAnswering | 25.96 | 26.40 | 98.33% | 24.10 | 107.73% |
| AllenaiLongformerBase | 21.03 | 10.65 | 197.50% | 10.49 | 200.58% |
| BartForCausalLM | 7.77 | 9.76 | 79.63% | 8.79 | 88.46% |
| BartForConditionalGeneration | 14.44 | 12.86 | 112.26% | 11.96 | 120.70% |
| BertForMaskedLM | 8.10 | 8.82 | 91.89% | 8.57 | 94.53% |
| BertForQuestionAnswering | 6.82 | 7.32 | 93.20% | 7.10 | 96.18% |
| BlenderbotForCausalLM | 10.97 | 11.39 | 96.34% | 10.10 | 108.65% |
| BlenderbotSmallForCausalLM | 5.91 | 5.44 | 108.72% | 4.82 | 122.67% |
| BlenderbotSmallForConditionalGeneration | 12.64 | 9.65 | 130.94% | 9.11 | 138.83% |
| CamemBert | 8.35 | 9.15 | 91.24% | 8.86 | 94.27% |
| DebertaForMaskedLM | 10.92 | 6.09 | 179.44% | 5.90 | 185.05% |
| DebertaForQuestionAnswering | 14.29 | 7.70 | 185.59% | 7.26 | 196.75% |
| DebertaV2ForMaskedLM | 15.47 | 10.22 | 151.32% | 9.34 | 165.55% |
| DebertaV2ForQuestionAnswering | 14.98 | 6.11 | 245.28% | 6.28 | 238.40% |
| DistilBertForMaskedLM | 8.37 | 8.70 | 96.30% | 8.22 | 101.92% |
| DistilBertForQuestionAnswering | 10.21 | 10.54 | 96.88% | 10.39 | 98.36% |
| DistillGPT2 | 8.77 | 6.78 | 129.40% | 6.31 | 138.88% |
| ElectraForCausalLM | 10.32 | 4.70 | 219.45% | 4.60 | 224.29% |
| ElectraForQuestionAnswering | 11.48 | 5.62 | 204.20% | 5.44 | 210.95% |
| GPT2ForSequenceClassification | 6.21 | 5.72 | 108.50% | 5.58 | 111.26% |
| GoogleFnet | 26.51 | 20.81 | 127.37% | 19.91 | 133.11% |
| LayoutLMForMaskedLM | 12.09 | 7.99 | 151.28% | 7.66 | 157.80% |
| LayoutLMForSequenceClassification | 10.62 | 6.49 | 163.67% | 6.25 | 169.95% |
| M2M100ForConditionalGeneration | 14.98 | 10.20 | 146.79% | 9.89 | 151.42% |
| MBartForCausalLM | 7.67 | 9.78 | 78.44% | 8.87 | 86.55% |
| MBartForConditionalGeneration | 13.45 | 12.69 | 105.99% | 12.03 | 111.82% |
| MT5ForConditionalGeneration | 19.96 | 5.32 | 375.37% | 5.08 | 393.01% |
| MegatronBertForCausalLM | 13.22 | 7.86 | 168.07% | 7.18 | 184.01% |
| MegatronBertForQuestionAnswering | 15.62 | 11.81 | 132.21% | 11.02 | 141.68% |
| MobileBertForMaskedLM | 26.63 | 10.82 | 245.99% | 11.95 | 222.73% |
| MobileBertForQuestionAnswering | 23.53 | 7.55 | 311.51% | 9.53 | 247.03% |
| OPTForCausalLM | 7.33 | 7.64 | 95.93% | 7.56 | 96.90% |
| PLBartForCausalLM | 8.73 | 7.63 | 114.40% | 7.37 | 118.58% |
| PLBartForConditionalGeneration | 10.46 | 8.50 | 122.98% | 8.16 | 128.13% |
| PegasusForCausalLM | 7.18 | 7.37 | 97.42% | 6.64 | 108.22% |
| PegasusForConditionalGeneration | 16.47 | 16.66 | 98.87% | 14.18 | 116.13% |
| RobertaForCausalLM | 10.30 | 9.95 | 103.52% | 9.52 | 108.25% |
| RobertaForQuestionAnswering | 6.37 | 7.13 | 89.28% | 6.79 | 93.87% |
| T5ForConditionalGeneration | 12.40 | 6.72 | 184.51% | 6.48 | 191.16% |
| T5Small | 12.02 | 6.66 | 180.55% | 6.32 | 190.33% |
| TrOCRForCausalLM | 14.12 | 13.31 | 106.11% | 12.45 | 113.41% |
| XGLMForCausalLM | 16.48 | 6.23 | 264.52% | 6.35 | 259.51% |
| XLNetLMHeadModel | 74.87 | 62.23 | 120.32% | 57.95 | 129.19% |
| YituTechConvBert | 20.21 | 10.50 | 192.48% | 9.97 | 202.72% |

We are also seeing a ~9% improvement on an internal addmm benchmark.

This PR will also slightly reduce compilation time for AMD max-autotune: before this change we assessed every config with matrix_instr_nonkdim in [0, 16], but with this update we remove that sweep and use 16 for all configs.

There is currently no CI to test max-autotune performance, but this will be enabled via https://github.com/pytorch/pytorch/pull/148672, after which we can investigate more tuning updates and config pruning.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147315
Approved by: https://github.com/jansel, https://github.com/eellison
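For readers unfamiliar with the AMD-specific kernargs named above, here is a hedged sketch of how they typically ride along in a Triton autotune config on ROCm. The kernel and block-size meta-parameter names are hypothetical; the sketch assumes a Triton build whose AMD backend accepts `waves_per_eu`, `matrix_instr_nonkdim`, and `kpack` as config kwargs, which is the common pattern in ROCm Triton kernels:

```python
import triton

# Hypothetical autotune candidate for a ROCm GEMM kernel. The BLOCK_*/GROUP_M
# entries are ordinary kernel meta-parameters; the last three are consumed by
# Triton's AMD backend rather than the kernel body.
rocm_config = triton.Config(
    {
        "BLOCK_M": 128,
        "BLOCK_N": 128,
        "BLOCK_K": 64,
        "GROUP_M": 16,
        "matrix_instr_nonkdim": 16,  # MFMA instruction shape along non-K dims
        "waves_per_eu": 2,           # occupancy hint per execution unit
        "kpack": 2,                  # K-dimension packing factor
    },
    num_stages=2,
    num_warps=8,
)
```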
This commit is contained in:
parent 886d9acb0d
commit 2299087220
```diff
@@ -72,8 +72,7 @@ def mm_options(config, sym_m, sym_n, sym_k, layout):
         not inductor_config.force_same_precision
         or ((sym_m % 16) == 0 and (sym_n % 16) == 0 and (sym_k % 8) == 0)
     )
-    return dict(
-        GROUP_M=8,
+    options_dict = dict(
         EVEN_K=even_k_symbolic,
         ALLOW_TF32=allow_tf32,
         USE_FAST_ACCUM=False,  # Option for _scaled_mm
@@ -83,6 +82,13 @@ def mm_options(config, sym_m, sym_n, sym_k, layout):
         **config.kwargs,
     )
 
+    # If GROUP_M not specified then default to 8
+    if "GROUP_M" not in config.kwargs:
+        group_m = config.kwargs.get("GROUP_M", 8)
+        options_dict["GROUP_M"] = group_m
+
+    return options_dict
+
 
 def persistent_mm_options(mat1, mat2):
     return dict(
```
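As a standalone illustration of the GROUP_M defaulting introduced above, here is a minimal sketch; `cfg_kwargs` and `build_options` are hypothetical stand-ins for the config's kwargs and `mm_options`, not the Inductor API itself:

```python
# Tuned configs may carry their own GROUP_M; anything absent falls back to 8.
def build_options(cfg_kwargs: dict) -> dict:
    options = dict(EVEN_K=True, ALLOW_TF32=True, **cfg_kwargs)
    if "GROUP_M" not in cfg_kwargs:
        options["GROUP_M"] = 8  # default when the config does not tune it
    return options

assert build_options({})["GROUP_M"] == 8
assert build_options({"GROUP_M": 16})["GROUP_M"] == 16
```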
```diff
@@ -1282,6 +1282,7 @@ class TritonTemplate(KernelTemplate):
             ),
             "num_stages": num_stages,
             "num_warps": num_warps,
+            "GROUP_M": kwargs.get("GROUP_M", -1),
             "allow_tf32": str(kwargs.get("ALLOW_TF32", None)),
             "acc_type": str(kwargs.get("ACC_TYPE", None)),
         },
```
```diff
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
+import dataclasses
 import itertools
-from collections import namedtuple
 from functools import partial
 from threading import Lock
 from typing import Any, Callable, TYPE_CHECKING
```
```diff
@@ -14,12 +14,59 @@ from .virtualized import V
 
 
 if TYPE_CHECKING:
-    from collections.abc import Generator, Sequence
+    from collections.abc import Generator
 
     from triton import Config as TritonConfig
 
 
-class BaseConfigSingleton(type):
+@dataclasses.dataclass
+class BaseConfig:
+    """
+    Base Gemm configuration used for most backends (CPU, CUDA)
+    """
+
+    block_m: int
+    block_n: int
+    block_k: int
+    num_stages: int
+    num_warps: int
+
+
+@dataclasses.dataclass
+class GemmConfig(BaseConfig):
+    """
+    Gemm configuration used for most backends (CPU, CUDA)
+    """
+
+    group_m: int = 8
+
+
+ConvConfig = BaseConfig
+
+
+@dataclasses.dataclass
+class ROCmGemmConfig(GemmConfig):
+    """
+    ROCm subclass for GEMMs, with AMD backend specific tuneable kernargs
+    """
+
+    matrix_instr_nonkdim: int = 16
+    waves_per_eu: int = 0
+    kpack: int = 2
+
+
+@dataclasses.dataclass
+class ROCmConvConfig(ConvConfig):
+    """
+    ROCm subclass for Conv, with AMD backend specific tuneable kernargs
+    """
+
+    matrix_instr_nonkdim: int = 16
+    waves_per_eu: int = 0
+    kpack: int = 2
+
+
+class BaseHeuristicSingleton(type):
     """
     Thread-safe implementation of single to be used in the config heuristic subclasses
     to ensure heavy __init__ calls are not repeatedly run
```
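To see how the hierarchy composes, here is a hedged sketch that re-declares the dataclasses locally (so it runs standalone) and shows the AMD defaults filling in, plus the `getattr` probing the heuristics rely on for optional fields:

```python
import dataclasses

@dataclasses.dataclass
class BaseConfig:
    block_m: int
    block_n: int
    block_k: int
    num_stages: int
    num_warps: int

@dataclasses.dataclass
class GemmConfig(BaseConfig):
    group_m: int = 8

@dataclasses.dataclass
class ROCmGemmConfig(GemmConfig):
    matrix_instr_nonkdim: int = 16
    waves_per_eu: int = 0
    kpack: int = 2

# Positional order is (block_m, block_n, block_k, num_stages, num_warps, ...).
cfg = ROCmGemmConfig(128, 128, 64, 2, 8, group_m=16)
assert dataclasses.asdict(cfg)["matrix_instr_nonkdim"] == 16  # AMD default applied
assert getattr(cfg, "group_m", None) == 16  # optional field probed via getattr
assert getattr(BaseConfig(32, 32, 16, 1, 2), "group_m", None) is None
```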
```diff
@@ -29,7 +76,7 @@ class BaseConfigSingleton(type):
     _lock: Lock = Lock()
 
     def __call__(
-        cls: BaseConfigSingleton, *args: Any, **kwargs: Any
+        cls: BaseHeuristicSingleton, *args: Any, **kwargs: Any
     ) -> BaseConfigHeuristic:
         with cls._lock:
             if cls not in cls._instances:
```
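The singleton metaclass guards heuristic construction behind a lock so the heavy `__init__` (building the config lists above) runs once per class. A minimal sketch of the same pattern, with simplified names:

```python
from threading import Lock
from typing import Any

class Singleton(type):
    _instances: dict[type, Any] = {}
    _lock: Lock = Lock()

    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
        # Serialize instantiation; only the first call runs __init__.
        with cls._lock:
            if cls not in cls._instances:
                cls._instances[cls] = super().__call__(*args, **kwargs)
            return cls._instances[cls]

class Heuristic(metaclass=Singleton):
    def __init__(self) -> None:
        self.configs = ["expensive", "to", "build"]

assert Heuristic() is Heuristic()  # __init__ ran exactly once
```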
```diff
@@ -38,12 +85,7 @@ class BaseConfigSingleton(type):
         return cls._instances[cls]
 
 
-Config = namedtuple(
-    "Config", ["block_m", "block_n", "block_k", "num_stages", "num_warps"]
-)
-
-
-class BaseConfigHeuristic(metaclass=BaseConfigSingleton):
+class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
     """
     Base class for mm_configs, device specific triton kernels config inherit from here
     """
```
```diff
@@ -52,36 +94,37 @@ class BaseConfigHeuristic(metaclass=BaseConfigSingleton):
         # List of dictionaries to store the kernel configs. Configs that evaluate to true
         # will be utilised on the target platform. The configs are as follows:
         # (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
-        self.mm_configs = [
-            Config(32, 32, 16, 1, 2),
-            Config(32, 32, 128, 2, 4),
-            Config(32, 64, 32, 5, 8),
-            Config(64, 32, 32, 5, 8),
-            Config(64, 32, 128, 5, 4),
-            Config(64, 64, 16, 2, 4),
-            Config(64, 64, 32, 2, 4),
-            Config(64, 64, 64, 3, 8),
-            Config(64, 64, 128, 5, 4),
-            Config(64, 128, 32, 3, 4),
-            Config(64, 128, 32, 4, 8),
-            Config(64, 128, 64, 3, 4),
-            Config(64, 128, 128, 4, 4),
-            Config(128, 64, 32, 3, 4),
-            Config(128, 64, 32, 4, 8),
-            Config(128, 128, 32, 2, 8),
-            Config(128, 128, 32, 3, 4),
-            Config(128, 128, 64, 3, 4),
-            Config(128, 128, 64, 5, 8),
+        self.mm_configs: list[BaseConfig] = [
+            GemmConfig(32, 32, 16, 1, 2),
+            GemmConfig(32, 32, 128, 2, 4),
+            GemmConfig(32, 64, 32, 5, 8),
+            GemmConfig(64, 32, 32, 5, 8),
+            GemmConfig(64, 32, 128, 5, 4),
+            GemmConfig(64, 64, 16, 2, 4),
+            GemmConfig(64, 64, 32, 2, 4),
+            GemmConfig(64, 64, 64, 3, 8),
+            GemmConfig(64, 64, 128, 5, 4),
+            GemmConfig(64, 128, 32, 3, 4),
+            GemmConfig(64, 128, 32, 4, 8),
+            GemmConfig(64, 128, 64, 3, 4),
+            GemmConfig(64, 128, 128, 4, 4),
+            GemmConfig(128, 64, 32, 3, 4),
+            GemmConfig(128, 64, 32, 4, 8),
+            GemmConfig(128, 128, 32, 2, 8),
+            GemmConfig(128, 128, 32, 3, 4),
+            GemmConfig(128, 128, 64, 3, 4),
+            GemmConfig(128, 128, 64, 5, 8),
         ]
 
         # Exhaustive search for mm configs
-        self.exhaustive_configs = [
-            Config(BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
+        self.exhaustive_configs: list[BaseConfig] = [
+            GemmConfig(BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps, group_m)
             for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product(
                 [16, 32, 64, 128, 256], repeat=3
             )
             for num_stages in [1, 2, 3, 4, 5]
             for num_warps in [2, 4, 8]
+            for group_m in [8]
         ]
 
         # these are only used in tuned_mm when AutoHeuristic is enabled
```
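For a sense of scale, the exhaustive search space above multiplies out as follows (a quick check, mirroring the comprehension's loop structure):

```python
import itertools

# 5^3 block-shape combinations x 5 stage counts x 3 warp counts x 1 group_m.
n_configs = sum(
    1
    for _ in itertools.product([16, 32, 64, 128, 256], repeat=3)
    for _ in [1, 2, 3, 4, 5]
    for _ in [2, 4, 8]
    for _ in [8]
)
assert n_configs == 125 * 5 * 3 * 1  # 1875 candidate configs
```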
```diff
@@ -89,220 +132,237 @@ class BaseConfigHeuristic(metaclass=BaseConfigSingleton):
         # when the learned heuristic is used, the learned heuristic reduces the number of configs down to 10
         # which saves compilation time (since less configs are autotuned) and potentially increase performance
         # because the learned heuristic might predict a config that is not part mm_configs
-        self.extra_mm_configs = [
-            Config(16, 32, 16, 3, 2),
-            Config(16, 32, 32, 4, 2),
-            Config(16, 32, 32, 5, 2),
-            Config(64, 64, 128, 3, 4),
-            Config(128, 64, 32, 2, 2),
-            Config(128, 64, 64, 3, 8),
-            Config(128, 64, 128, 4, 8),
-            Config(128, 128, 32, 4, 4),
-            Config(128, 128, 64, 3, 8),
-            Config(128, 128, 64, 5, 4),
+        self.extra_mm_configs: list[BaseConfig] = [
+            GemmConfig(16, 32, 16, 3, 2),
+            GemmConfig(16, 32, 32, 4, 2),
+            GemmConfig(16, 32, 32, 5, 2),
+            GemmConfig(64, 64, 128, 3, 4),
+            GemmConfig(128, 64, 32, 2, 2),
+            GemmConfig(128, 64, 64, 3, 8),
+            GemmConfig(128, 64, 128, 4, 8),
+            GemmConfig(128, 128, 32, 4, 4),
+            GemmConfig(128, 128, 64, 3, 8),
+            GemmConfig(128, 128, 64, 5, 4),
         ]
 
-        self.int8_mm_configs = [
-            Config(64, 64, 32, 2, 4),
-            Config(64, 128, 32, 3, 4),
-            Config(128, 64, 32, 3, 4),
-            Config(64, 128, 32, 4, 8),
-            Config(128, 64, 32, 4, 8),
-            Config(64, 32, 32, 5, 8),
-            Config(32, 64, 32, 5, 8),
-            Config(128, 128, 32, 2, 8),
-            Config(64, 64, 64, 3, 8),
-            Config(128, 256, 128, 3, 8),
-            Config(256, 128, 128, 3, 8),
+        self.int8_mm_configs: list[BaseConfig] = [
+            GemmConfig(64, 64, 32, 2, 4),
+            GemmConfig(64, 128, 32, 3, 4),
+            GemmConfig(128, 64, 32, 3, 4),
+            GemmConfig(64, 128, 32, 4, 8),
+            GemmConfig(128, 64, 32, 4, 8),
+            GemmConfig(64, 32, 32, 5, 8),
+            GemmConfig(32, 64, 32, 5, 8),
+            GemmConfig(128, 128, 32, 2, 8),
+            GemmConfig(64, 64, 64, 3, 8),
+            GemmConfig(128, 256, 128, 3, 8),
+            GemmConfig(256, 128, 128, 3, 8),
         ]
 
-        self.mixed_mm_configs = [
-            Config(16, 128, 256, 3, 4),
-            Config(16, 128, 256, 5, 8),
+        self.mixed_mm_configs: list[BaseConfig] = [
+            GemmConfig(16, 128, 256, 3, 4),
+            GemmConfig(16, 128, 256, 5, 8),
         ]
 
-        self.persistent_mm_configs = [
-            Config(128, 256, 64, 3, 8),
-            Config(128, 128, 64, 3, 8),
-            Config(128, 128, 128, 3, 8),
-            Config(128, 128, 128, 3, 4),
-            Config(128, 128, 64, 4, 8),
+        self.persistent_mm_configs: list[BaseConfig] = [
+            GemmConfig(128, 256, 64, 3, 8),
+            GemmConfig(128, 128, 64, 3, 8),
+            GemmConfig(128, 128, 128, 3, 8),
+            GemmConfig(128, 128, 128, 3, 4),
+            GemmConfig(128, 128, 64, 4, 8),
         ]
 
-        self.scaled_mm_configs = [
-            Config(128, 256, 32, 3, 8),
-            Config(256, 128, 32, 3, 8),
-            Config(256, 64, 32, 4, 4),
-            Config(64, 256, 32, 4, 4),
-            Config(128, 128, 32, 4, 4),
-            Config(128, 64, 32, 4, 4),
-            Config(64, 128, 32, 4, 4),
-            Config(128, 32, 32, 4, 4),
-            Config(64, 32, 32, 5, 2),
-            Config(256, 128, 128, 3, 8),
-            Config(256, 64, 128, 4, 4),
-            Config(64, 256, 128, 4, 4),
-            Config(128, 128, 128, 4, 4),
-            Config(128, 64, 64, 4, 4),
-            Config(64, 128, 64, 4, 4),
-            Config(128, 32, 64, 4, 4),
-            Config(64, 32, 64, 5, 2),
-            Config(16, 32, 32, 2, 2),
-            Config(16, 64, 32, 2, 2),
-            Config(16, 128, 32, 2, 4),
-            Config(16, 256, 32, 2, 4),
-            Config(16, 32, 64, 2, 2),
-            Config(16, 64, 64, 2, 2),
-            Config(16, 128, 64, 2, 4),
-            Config(16, 256, 64, 2, 4),
-            Config(32, 32, 32, 2, 2),
-            Config(32, 64, 32, 2, 2),
-            Config(32, 128, 32, 2, 4),
-            Config(32, 256, 32, 2, 4),
-            Config(32, 32, 64, 2, 2),
-            Config(32, 64, 64, 2, 2),
-            Config(32, 128, 64, 2, 4),
-            Config(32, 256, 64, 2, 4),
-            Config(16, 32, 32, 3, 2),
-            Config(16, 64, 32, 3, 2),
-            Config(16, 128, 32, 3, 4),
-            Config(16, 256, 32, 3, 4),
-            Config(16, 32, 64, 3, 2),
-            Config(16, 64, 64, 3, 2),
-            Config(16, 128, 64, 3, 4),
-            Config(16, 256, 64, 3, 4),
-            Config(32, 32, 32, 3, 2),
-            Config(32, 64, 32, 3, 2),
-            Config(32, 128, 32, 3, 4),
-            Config(32, 256, 32, 3, 4),
-            Config(32, 32, 64, 3, 2),
-            Config(32, 64, 64, 3, 2),
-            Config(32, 128, 64, 3, 4),
-            Config(32, 256, 64, 3, 4),
-            Config(16, 32, 32, 4, 2),
-            Config(16, 64, 32, 4, 2),
-            Config(16, 128, 32, 4, 4),
-            Config(16, 256, 32, 4, 4),
-            Config(16, 32, 64, 4, 2),
-            Config(16, 64, 64, 4, 2),
-            Config(16, 128, 64, 4, 4),
-            Config(16, 256, 64, 4, 4),
-            Config(32, 32, 32, 4, 2),
-            Config(32, 64, 32, 4, 2),
-            Config(32, 128, 32, 4, 4),
-            Config(32, 256, 32, 4, 4),
-            Config(32, 32, 64, 4, 2),
-            Config(32, 64, 64, 4, 2),
-            Config(32, 128, 64, 4, 4),
-            Config(32, 256, 64, 4, 4),
-            Config(16, 32, 32, 5, 2),
-            Config(16, 64, 32, 5, 2),
-            Config(16, 128, 32, 5, 4),
-            Config(16, 256, 32, 5, 4),
-            Config(16, 32, 64, 5, 2),
-            Config(16, 64, 64, 5, 2),
-            Config(16, 128, 64, 5, 4),
-            Config(16, 256, 64, 5, 4),
-            Config(32, 32, 32, 5, 2),
-            Config(32, 64, 32, 5, 2),
-            Config(32, 128, 32, 5, 4),
-            Config(32, 256, 32, 5, 4),
-            Config(32, 32, 64, 5, 2),
-            Config(32, 64, 64, 5, 2),
-            Config(32, 128, 64, 5, 4),
-            Config(32, 256, 64, 5, 4),
-            Config(16, 32, 32, 6, 2),
-            Config(16, 64, 32, 6, 2),
-            Config(16, 128, 32, 6, 4),
-            Config(16, 256, 32, 6, 4),
-            Config(16, 32, 64, 6, 2),
-            Config(16, 64, 64, 6, 2),
-            Config(16, 128, 64, 6, 4),
-            Config(16, 256, 64, 6, 4),
-            Config(32, 32, 32, 6, 2),
-            Config(32, 64, 32, 6, 2),
-            Config(32, 128, 32, 6, 4),
-            Config(32, 256, 32, 6, 4),
-            Config(32, 32, 64, 6, 2),
-            Config(32, 64, 64, 6, 2),
-            Config(32, 128, 64, 6, 4),
-            Config(32, 256, 64, 6, 4),
+        self.scaled_mm_configs: list[BaseConfig] = [
+            GemmConfig(128, 256, 32, 3, 8),
+            GemmConfig(256, 128, 32, 3, 8),
+            GemmConfig(256, 64, 32, 4, 4),
+            GemmConfig(64, 256, 32, 4, 4),
+            GemmConfig(128, 128, 32, 4, 4),
+            GemmConfig(128, 64, 32, 4, 4),
+            GemmConfig(64, 128, 32, 4, 4),
+            GemmConfig(128, 32, 32, 4, 4),
+            GemmConfig(64, 32, 32, 5, 2),
+            GemmConfig(256, 128, 128, 3, 8),
+            GemmConfig(256, 64, 128, 4, 4),
+            GemmConfig(64, 256, 128, 4, 4),
+            GemmConfig(128, 128, 128, 4, 4),
+            GemmConfig(128, 64, 64, 4, 4),
+            GemmConfig(64, 128, 64, 4, 4),
+            GemmConfig(128, 32, 64, 4, 4),
+            GemmConfig(64, 32, 64, 5, 2),
+            GemmConfig(16, 32, 32, 2, 2),
+            GemmConfig(16, 64, 32, 2, 2),
+            GemmConfig(16, 128, 32, 2, 4),
+            GemmConfig(16, 256, 32, 2, 4),
+            GemmConfig(16, 32, 64, 2, 2),
+            GemmConfig(16, 64, 64, 2, 2),
+            GemmConfig(16, 128, 64, 2, 4),
+            GemmConfig(16, 256, 64, 2, 4),
+            GemmConfig(32, 32, 32, 2, 2),
+            GemmConfig(32, 64, 32, 2, 2),
+            GemmConfig(32, 128, 32, 2, 4),
+            GemmConfig(32, 256, 32, 2, 4),
+            GemmConfig(32, 32, 64, 2, 2),
+            GemmConfig(32, 64, 64, 2, 2),
+            GemmConfig(32, 128, 64, 2, 4),
+            GemmConfig(32, 256, 64, 2, 4),
+            GemmConfig(16, 32, 32, 3, 2),
+            GemmConfig(16, 64, 32, 3, 2),
+            GemmConfig(16, 128, 32, 3, 4),
+            GemmConfig(16, 256, 32, 3, 4),
+            GemmConfig(16, 32, 64, 3, 2),
+            GemmConfig(16, 64, 64, 3, 2),
+            GemmConfig(16, 128, 64, 3, 4),
+            GemmConfig(16, 256, 64, 3, 4),
+            GemmConfig(32, 32, 32, 3, 2),
+            GemmConfig(32, 64, 32, 3, 2),
+            GemmConfig(32, 128, 32, 3, 4),
+            GemmConfig(32, 256, 32, 3, 4),
+            GemmConfig(32, 32, 64, 3, 2),
+            GemmConfig(32, 64, 64, 3, 2),
+            GemmConfig(32, 128, 64, 3, 4),
+            GemmConfig(32, 256, 64, 3, 4),
+            GemmConfig(16, 32, 32, 4, 2),
+            GemmConfig(16, 64, 32, 4, 2),
+            GemmConfig(16, 128, 32, 4, 4),
+            GemmConfig(16, 256, 32, 4, 4),
+            GemmConfig(16, 32, 64, 4, 2),
+            GemmConfig(16, 64, 64, 4, 2),
+            GemmConfig(16, 128, 64, 4, 4),
+            GemmConfig(16, 256, 64, 4, 4),
+            GemmConfig(32, 32, 32, 4, 2),
+            GemmConfig(32, 64, 32, 4, 2),
+            GemmConfig(32, 128, 32, 4, 4),
+            GemmConfig(32, 256, 32, 4, 4),
+            GemmConfig(32, 32, 64, 4, 2),
+            GemmConfig(32, 64, 64, 4, 2),
+            GemmConfig(32, 128, 64, 4, 4),
+            GemmConfig(32, 256, 64, 4, 4),
+            GemmConfig(16, 32, 32, 5, 2),
+            GemmConfig(16, 64, 32, 5, 2),
+            GemmConfig(16, 128, 32, 5, 4),
+            GemmConfig(16, 256, 32, 5, 4),
+            GemmConfig(16, 32, 64, 5, 2),
+            GemmConfig(16, 64, 64, 5, 2),
+            GemmConfig(16, 128, 64, 5, 4),
+            GemmConfig(16, 256, 64, 5, 4),
+            GemmConfig(32, 32, 32, 5, 2),
+            GemmConfig(32, 64, 32, 5, 2),
+            GemmConfig(32, 128, 32, 5, 4),
+            GemmConfig(32, 256, 32, 5, 4),
+            GemmConfig(32, 32, 64, 5, 2),
+            GemmConfig(32, 64, 64, 5, 2),
+            GemmConfig(32, 128, 64, 5, 4),
+            GemmConfig(32, 256, 64, 5, 4),
+            GemmConfig(16, 32, 32, 6, 2),
+            GemmConfig(16, 64, 32, 6, 2),
+            GemmConfig(16, 128, 32, 6, 4),
+            GemmConfig(16, 256, 32, 6, 4),
+            GemmConfig(16, 32, 64, 6, 2),
+            GemmConfig(16, 64, 64, 6, 2),
+            GemmConfig(16, 128, 64, 6, 4),
+            GemmConfig(16, 256, 64, 6, 4),
+            GemmConfig(32, 32, 32, 6, 2),
+            GemmConfig(32, 64, 32, 6, 2),
+            GemmConfig(32, 128, 32, 6, 4),
+            GemmConfig(32, 256, 32, 6, 4),
+            GemmConfig(32, 32, 64, 6, 2),
+            GemmConfig(32, 64, 64, 6, 2),
+            GemmConfig(32, 128, 64, 6, 4),
+            GemmConfig(32, 256, 64, 6, 4),
         ]
 
-        self.scaled_persistent_mm_configs = [
-            Config(128, 128, 64, 3, 8),
-            Config(128, 128, 128, 3, 8),
-            Config(128, 128, 128, 4, 8),
-            Config(128, 128, 128, 4, 4),
-            Config(128, 128, 128, 3, 4),
-            Config(128, 128, 128, 5, 4),
-            Config(128, 128, 128, 5, 8),
-            Config(128, 128, 128, 6, 8),
-            Config(128, 128, 64, 4, 8),
+        self.scaled_persistent_mm_configs: list[BaseConfig] = [
+            GemmConfig(128, 128, 64, 3, 8),
+            GemmConfig(128, 128, 128, 3, 8),
+            GemmConfig(128, 128, 128, 4, 8),
+            GemmConfig(128, 128, 128, 4, 4),
+            GemmConfig(128, 128, 128, 3, 4),
+            GemmConfig(128, 128, 128, 5, 4),
+            GemmConfig(128, 128, 128, 5, 8),
+            GemmConfig(128, 128, 128, 6, 8),
+            GemmConfig(128, 128, 64, 4, 8),
         ]
 
         # TODO: Unify with other gemm patterns, mm_plus_mm currently follows
         # slightly different pattern than rest
-        self.mm_plus_mm_configs = [
-            Config(64, 64, 32, 2, 4),
-            Config(64, 64, 32, 3, 8),
-            Config(64, 64, 32, 4, 16),
-            Config(64, 32, 32, 4, 8),
-            Config(32, 64, 32, 4, 8),
-            Config(128, 128, 32, 1, 8),
-            Config(64, 64, 64, 1, 8),
-            Config(32, 32, 128, 1, 8),
-            Config(64, 64, 16, 2, 4),
-            Config(32, 32, 16, 1, 2),
+        self.mm_plus_mm_configs: list[BaseConfig] = [
+            GemmConfig(64, 64, 32, 2, 4),
+            GemmConfig(64, 64, 32, 3, 8),
+            GemmConfig(64, 64, 32, 4, 16),
+            GemmConfig(64, 32, 32, 4, 8),
+            GemmConfig(32, 64, 32, 4, 8),
+            GemmConfig(128, 128, 32, 1, 8),
+            GemmConfig(64, 64, 64, 1, 8),
+            GemmConfig(32, 32, 128, 1, 8),
+            GemmConfig(64, 64, 16, 2, 4),
+            GemmConfig(32, 32, 16, 1, 2),
         ]
 
-        self.conv_configs = [
-            Config(64, 256, 16, 2, 4),
-            Config(256, 64, 16, 2, 4),
-            Config(1024, 16, 16, 1, 8),
-            Config(128, 128, 32, 2, 8),
-            Config(64, 64, 32, 2, 4),
-            Config(64, 256, 32, 2, 8),
-            Config(256, 64, 32, 2, 8),
+        self.conv_configs: list[BaseConfig] = [
+            ConvConfig(64, 256, 16, 2, 4),
+            ConvConfig(256, 64, 16, 2, 4),
+            ConvConfig(1024, 16, 16, 1, 8),
+            ConvConfig(128, 128, 32, 2, 8),
+            ConvConfig(64, 64, 32, 2, 4),
+            ConvConfig(64, 256, 32, 2, 8),
+            ConvConfig(256, 64, 32, 2, 8),
         ]
 
     def _finalize_mm_configs(
         self,
-        configs: list[Config],
+        configs: list[BaseConfig],
     ) -> Generator[TritonConfig, None, None]:
         """
         Finalizes configs after scaling, applying additional constraints.
         """
-        used = OrderedSet[Config]()
+        used: OrderedSet[tuple[int, ...]] = OrderedSet()
 
         max_mm_configs = config.test_configs.max_mm_configs
 
-        for block_m, block_n, block_k, num_stages, num_warps in configs:
+        for conf in configs:
             # Each warp computes a 16x16 tile = 256 elements
-            num_warps = min(num_warps, block_m * block_n // 256)
+            num_warps = min(conf.num_warps, conf.block_m * conf.block_n // 256)
 
-            if (
-                Config(block_m, block_n, block_k, num_stages, num_warps)
-            ) not in used and (max_mm_configs is None or len(used) < max_mm_configs):
-                used.add(Config(block_m, block_n, block_k, num_stages, num_warps))
-                yield self.triton_config(
-                    BLOCK_M=block_m,
-                    BLOCK_N=block_n,
-                    BLOCK_K=block_k,
-                    num_stages=num_stages,
-                    num_warps=num_warps,
-                )
+            # Construct key for finding duplicate configs
+            key: tuple[int, ...] = (
+                conf.block_m,
+                conf.block_n,
+                conf.block_k,
+                conf.num_stages,
+                num_warps,
+            )
+
+            # Check if gemm specific arg exists - add to key if does
+            group_m = getattr(conf, "group_m", None)
+            if group_m is not None:
+                key += (group_m,)
+
+            if key not in used and (
+                max_mm_configs is None or len(used) < max_mm_configs
+            ):
+                used.add(key)
+                kwargs = {
+                    "BLOCK_M": conf.block_m,
+                    "BLOCK_N": conf.block_n,
+                    "BLOCK_K": conf.block_k,
+                    "num_stages": conf.num_stages,
+                    "num_warps": num_warps,
+                }
+                if group_m is not None:
+                    kwargs["GROUP_M"] = group_m
+                yield self.triton_config(**kwargs)
 
     def _scale_mm_configs(
         self,
         m: int,
         n: int,
         k: int,
-        configs: Sequence[Config],
+        configs: list[BaseConfig],
         scale: float,
         has_int8_tensor: bool,
         exclude: Callable[[int, int, int], bool],
-    ) -> list[Config]:
+    ) -> list[BaseConfig]:
         """
         Scales and filters matrix multiplication configs based on input size.
         """
```
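The dedup above replaces hashing whole namedtuples with a plain tuple key, so optional fields like `group_m` can join the key only when the config type defines them. A hedged standalone sketch of the same mechanism (`C` is a minimal stand-in for `BaseConfig`; Inductor uses its own insertion-ordered `OrderedSet` where this uses `set`):

```python
from dataclasses import dataclass

@dataclass
class C:  # minimal stand-in for BaseConfig
    block_m: int
    block_n: int
    block_k: int
    num_stages: int
    num_warps: int

def unique_keys(configs):
    used = set()
    for conf in configs:
        # Warp cap: each warp covers a 16x16 (= 256 element) tile.
        num_warps = min(conf.num_warps, conf.block_m * conf.block_n // 256)
        key = (conf.block_m, conf.block_n, conf.block_k, conf.num_stages, num_warps)
        group_m = getattr(conf, "group_m", None)  # present only on GemmConfig
        if group_m is not None:
            key += (group_m,)
        if key not in used:
            used.add(key)
            yield key

# Both configs cap to num_warps = 32*32//256 = 4, so they dedup to one key.
assert len(list(unique_keys([C(32, 32, 16, 1, 8), C(32, 32, 16, 1, 4)]))) == 1
```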
```diff
@@ -341,7 +401,8 @@ class BaseConfigHeuristic(metaclass=BaseConfigSingleton):
 
         scaled_configs = []
         for c in configs:
-            scaled_config = c._replace(
+            scaled_config = dataclasses.replace(
+                c,
                 block_m=max(min(int(c.block_m * scale), m), min_block_size),
                 block_n=max(min(int(c.block_n * scale), n), min_block_size),
                 block_k=max(min(int(c.block_k * scale), k), min_block_size_k),
```
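`dataclasses.replace` is the dataclass analogue of the namedtuple `._replace` used before this change: it returns a copy of the config with selected fields overridden. A quick self-contained check of the clamping pattern:

```python
import dataclasses

@dataclasses.dataclass
class Cfg:
    block_m: int
    block_n: int

c = Cfg(128, 128)
# Scale block_m by 0.5, clamp to [16, 64] as in _scale_mm_configs.
half = dataclasses.replace(c, block_m=max(min(int(c.block_m * 0.5), 64), 16))
assert (half.block_m, half.block_n) == (64, 128)
assert c.block_m == 128  # the original config is untouched
```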
```diff
@@ -359,7 +420,7 @@ class BaseConfigHeuristic(metaclass=BaseConfigSingleton):
         m: int,
         n: int,
         k: int,
-        configs: Sequence[Config],
+        configs: list[BaseConfig],
         has_int8_tensor: bool = False,
         scale: int = 1,
         exclude: Callable[[int, int, int], bool] = lambda m, n, k: False,
```
```diff
@@ -430,90 +491,160 @@ class ROCmConfigHeuristic(BaseConfigHeuristic):
 
         self.default_num_stages = get_backend_num_stages()
 
+        self.mm_configs: list[BaseConfig] = [
+            ROCmGemmConfig(
+                16, 16, 256, self.default_num_stages, 4, group_m=4, waves_per_eu=2
+            ),
+            ROCmGemmConfig(32, 16, 256, self.default_num_stages, 4, group_m=4),
+            ROCmGemmConfig(
+                32, 32, 16, self.default_num_stages, 4, group_m=8, waves_per_eu=2
+            ),
+            ROCmGemmConfig(32, 32, 128, self.default_num_stages, 4, group_m=8),
+            ROCmGemmConfig(32, 64, 64, self.default_num_stages, 4, group_m=8),
+            ROCmGemmConfig(
+                64, 16, 128, self.default_num_stages, 4, group_m=8, waves_per_eu=2
+            ),
+            ROCmGemmConfig(64, 32, 32, self.default_num_stages, 4, group_m=8),
+            ROCmGemmConfig(64, 32, 64, self.default_num_stages, 4, group_m=8),
+            ROCmGemmConfig(64, 32, 64, self.default_num_stages, 8, group_m=8),
+            ROCmGemmConfig(64, 32, 128, self.default_num_stages, 4, group_m=8),
+            ROCmGemmConfig(64, 64, 16, self.default_num_stages, 4, group_m=8),
+            ROCmGemmConfig(64, 64, 64, self.default_num_stages, 4, group_m=4),
+            ROCmGemmConfig(64, 64, 128, self.default_num_stages, 8, group_m=16),
+            ROCmGemmConfig(64, 64, 256, self.default_num_stages, 8, group_m=4),
+            ROCmGemmConfig(
+                64, 128, 32, self.default_num_stages, 4, group_m=4, waves_per_eu=2
+            ),
+            ROCmGemmConfig(64, 128, 32, self.default_num_stages, 8, group_m=8),
+            ROCmGemmConfig(64, 128, 64, self.default_num_stages, 8, group_m=4),
+            ROCmGemmConfig(64, 128, 128, self.default_num_stages, 8, group_m=4),
+            ROCmGemmConfig(128, 32, 32, self.default_num_stages, 4, group_m=8),
+            ROCmGemmConfig(128, 32, 64, self.default_num_stages, 4, group_m=8),
+            ROCmGemmConfig(
+                128, 64, 32, self.default_num_stages, 4, group_m=8, waves_per_eu=2
+            ),
+            ROCmGemmConfig(128, 64, 64, self.default_num_stages, 4, group_m=16),
+            ROCmGemmConfig(128, 64, 128, self.default_num_stages, 8, group_m=4),
+            ROCmGemmConfig(
+                128, 128, 32, self.default_num_stages, 4, group_m=16, waves_per_eu=2
+            ),
+            ROCmGemmConfig(128, 128, 32, self.default_num_stages, 8, group_m=16),
+            ROCmGemmConfig(
+                128, 128, 32, self.default_num_stages, 8, group_m=16, waves_per_eu=2
+            ),
+            ROCmGemmConfig(128, 128, 64, self.default_num_stages, 4, group_m=16),
+            ROCmGemmConfig(128, 128, 64, self.default_num_stages, 8, group_m=8),
+            ROCmGemmConfig(128, 128, 128, self.default_num_stages, 8, group_m=16),
+            ROCmGemmConfig(
+                128, 256, 32, self.default_num_stages, 4, group_m=16, waves_per_eu=2
+            ),
+            ROCmGemmConfig(128, 256, 64, self.default_num_stages, 8, group_m=4),
+            ROCmGemmConfig(256, 64, 64, self.default_num_stages, 8, group_m=4),
+            ROCmGemmConfig(
+                256, 128, 32, self.default_num_stages, 4, group_m=4, waves_per_eu=2
+            ),
+            ROCmGemmConfig(256, 128, 32, self.default_num_stages, 8, group_m=16),
+            ROCmGemmConfig(256, 128, 64, self.default_num_stages, 8, group_m=4),
+            ROCmGemmConfig(256, 256, 64, self.default_num_stages, 8, group_m=4),
+        ]
+
         # Exhaustive search for mm configs
-        self.exhaustive_configs = [
-            Config(BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
+        self.exhaustive_configs: list[BaseConfig] = [
+            ROCmGemmConfig(
+                BLOCK_M,
+                BLOCK_N,
+                BLOCK_K,
+                num_stages,
+                num_warps,
+                group_m,
+                matrix_instr_nonkdim,
+                waves_per_eu,
+                kpack,
+            )
             for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product(
                 [16, 32, 64, 128, 256], repeat=3
             )
             for num_stages in [1, self.default_num_stages]
             for num_warps in [4, 8]
+            for group_m in [4, 8, 16]
+            for matrix_instr_nonkdim in [0, 16]
+            for waves_per_eu in [0, 2]
+            for kpack in [2]
         ]
 
     def _filter_configs(
-        self, configs: list[Config], new_num_stages: int
-    ) -> list[Config]:
-        filtered_configs = [
-            c._replace(num_stages=self.default_num_stages) for c in configs
-        ]
-        return filtered_configs
+        self, configs: list[BaseConfig], new_num_stages: int
+    ) -> list[BaseConfig]:
+        # TODO: _filter_configs can be removed once backend specific configs are added
+        # for all methods
+        for c in configs:
+            c.num_stages = self.default_num_stages
+        return configs
 
     def _finalize_mm_configs(
         self,
-        configs: list[Config],
+        configs: list[BaseConfig],
     ) -> Generator[TritonConfig, None, None]:
-        used = OrderedSet[tuple[Config, int, int]]()
+        """
+        Finalizes configs after scaling, applying additional constraints.
+        """
+        used: OrderedSet[tuple[int, ...]] = OrderedSet()
 
         max_mm_configs = config.test_configs.max_mm_configs
-        for block_m, block_n, block_k, num_stages, num_warps in configs:
-            # each warp computes 16x16 tile = 256
-            num_warps = min(num_warps, block_m * block_n // 256)
-            kpack = 2
-            for matrix_instr_nonkdim in [0, 16]:
-                if matrix_instr_nonkdim != 0 and (
-                    block_m % matrix_instr_nonkdim != 0
-                    or block_n % matrix_instr_nonkdim != 0
-                ):
-                    # block_m and block_n must be a multiple of matrix_instr_nonkdim
-                    continue
-                if (
-                    Config(
-                        block_m,
-                        block_n,
-                        block_k,
-                        num_stages,
-                        num_warps,
-                    ),
-                    matrix_instr_nonkdim,
-                    kpack,
-                ) not in used and (
-                    max_mm_configs is None or len(used) < max_mm_configs
-                ):
-                    used.add(
-                        (
-                            Config(
-                                block_m,
-                                block_n,
-                                block_k,
-                                num_stages,
-                                num_warps,
-                            ),
-                            matrix_instr_nonkdim,
-                            kpack,
-                        )
-                    )
-
-                    yield self.triton_config(
-                        BLOCK_M=block_m,
-                        BLOCK_N=block_n,
-                        BLOCK_K=block_k,
-                        num_stages=num_stages,
-                        num_warps=num_warps,
-                        matrix_instr_nonkdim=matrix_instr_nonkdim,
-                        kpack=kpack,
-                    )
-
-    def get_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
-        filtered_configs = self._filter_configs(
-            self.mm_configs, self.default_num_stages
-        )
-        return partial(self.preprocess_mm_configs, configs=filtered_configs)
-
-    def get_exhaustive_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
-        filtered_configs = self._filter_configs(
-            self.exhaustive_configs, self.default_num_stages
-        )
-        return partial(self.preprocess_mm_configs, configs=filtered_configs)
+
+        for conf in configs:
+            # Each warp computes a 16x16 tile = 256 elements
+            conf.num_warps = min(conf.num_warps, conf.block_m * conf.block_n // 256)
+
+            # Defaults for AMD triton backend kern args if not set
+            matrix_instr_nonkdim = getattr(conf, "matrix_instr_nonkdim", 16)
+            waves_per_eu = getattr(conf, "waves_per_eu", 0)
+            kpack = getattr(conf, "kpack", 2)
+
+            if matrix_instr_nonkdim != 0 and (
+                conf.block_m % matrix_instr_nonkdim != 0
+                or conf.block_n % matrix_instr_nonkdim != 0
+            ):
+                # block_m and block_n must be a multiple of matrix_instr_nonkdim
+                continue
+
+            # Construct key for finding duplicate configs
+            key: tuple[int, ...] = (
+                conf.block_m,
+                conf.block_n,
+                conf.block_k,
+                conf.num_stages,
+                conf.num_warps,
+                waves_per_eu,
+                matrix_instr_nonkdim,
+                kpack,
+            )
+
+            # Check if gemm specific arg exists - add to key if does
+            group_m = getattr(conf, "group_m", None)
+            if group_m is not None:
+                key += (group_m,)
+
+            if waves_per_eu != 0:
+                waves_per_eu = int(8 // conf.num_warps)
+
+            if key not in used and (
+                max_mm_configs is None or len(used) < max_mm_configs
+            ):
+                used.add(key)
+                kwargs = {
+                    "BLOCK_M": conf.block_m,
+                    "BLOCK_N": conf.block_n,
+                    "BLOCK_K": conf.block_k,
+                    "num_stages": conf.num_stages,
+                    "num_warps": conf.num_warps,
+                    "matrix_instr_nonkdim": matrix_instr_nonkdim,
+                    "waves_per_eu": waves_per_eu,
+                    "kpack": kpack,
+                }
+                if group_m is not None:
+                    kwargs["GROUP_M"] = group_m
+                yield self.triton_config(**kwargs)
 
     def get_extra_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
         filtered_configs = self._filter_configs(
```
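The ROCm exhaustive comprehension above multiplies out as follows, before the `matrix_instr_nonkdim` divisibility check in `_finalize_mm_configs` prunes shapes whose `block_m`/`block_n` are not multiples of 16. A hedged sketch, using `[1, 2]` as a stand-in for `[1, self.default_num_stages]` since the actual default comes from `get_backend_num_stages()`:

```python
import itertools

n_configs = len(list(itertools.product(
    itertools.product([16, 32, 64, 128, 256], repeat=3),  # block shapes
    [1, 2],        # num_stages choices (1 and the backend default)
    [4, 8],        # num_warps
    [4, 8, 16],    # group_m
    [0, 16],       # matrix_instr_nonkdim
    [0, 2],        # waves_per_eu
    [2],           # kpack
)))
assert n_configs == 125 * 2 * 2 * 3 * 2 * 2  # 6000 candidates before pruning
```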