mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "[ROCm] ROCm-specific gemm tuning parameters" (#147388)
Summary: This diff reverts D69573225 / https://github.com/pytorch/pytorch/pull/143286 15% cold compile time regression, see https://fb.workplace.com/groups/1075192433118967/permalink/1608559059782299/ Test Plan: NA Differential Revision: D69790102 Pull Request resolved: https://github.com/pytorch/pytorch/pull/147388 Approved by: https://github.com/yanboliang
This commit is contained in:
parent
4ece056791
commit
465930ee81
|
|
@ -1,32 +1,20 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from typing import Any, Generator, Optional, TYPE_CHECKING
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
import sympy
|
||||
|
||||
import torch
|
||||
|
||||
from . import config
|
||||
from .codecache import write_text
|
||||
from .metrics import get_metric_table, is_metric_table_enabled
|
||||
from .runtime.hints import DeviceProperties, ReductionHint
|
||||
from .scheduler import BaseSchedulerNode, Scheduler, WhyNoFuse
|
||||
from .template_heuristics import (
|
||||
BaseConfigHeuristic,
|
||||
CPUConfigHeuristic,
|
||||
CUDAConfigHeuristic,
|
||||
ROCmConfigHeuristic,
|
||||
XPUConfigHeuristic,
|
||||
)
|
||||
from .virtualized import V
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from functools import partial
|
||||
|
||||
from triton import Config as TritonConfig
|
||||
|
||||
import torch
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
from .codegen.simd_kernel_features import SIMDKernelFeatures
|
||||
|
|
@ -53,80 +41,6 @@ class InductorChoices:
|
|||
torch._inductor.virtualized.V.set_choices_handler(MyHeuristics())
|
||||
"""
|
||||
|
||||
def get_config_heuristics(
|
||||
self, device_type: Optional[str] = "cuda"
|
||||
) -> BaseConfigHeuristic:
|
||||
if device_type == "cuda":
|
||||
if torch.version.hip is None:
|
||||
return CUDAConfigHeuristic()
|
||||
else:
|
||||
return ROCmConfigHeuristic()
|
||||
elif device_type == "xpu":
|
||||
return XPUConfigHeuristic()
|
||||
elif device_type == "cpu":
|
||||
return CPUConfigHeuristic()
|
||||
else:
|
||||
return BaseConfigHeuristic()
|
||||
|
||||
# GEMM configs
|
||||
def get_base_mm_configs(
|
||||
self, device_type: Optional[str] = "cuda"
|
||||
) -> partial[Generator[TritonConfig, None, None]]:
|
||||
mm_heuristics = self.get_config_heuristics(device_type)
|
||||
if config.max_autotune_gemm_search_space != "EXHAUSTIVE":
|
||||
return mm_heuristics.get_mm_configs()
|
||||
else:
|
||||
return mm_heuristics.get_exhaustive_mm_configs()
|
||||
|
||||
def get_extra_mm_configs(
|
||||
self, device_type: Optional[str] = "cuda"
|
||||
) -> partial[Generator[TritonConfig, None, None]]:
|
||||
mm_heuristics = self.get_config_heuristics(device_type)
|
||||
return mm_heuristics.get_extra_mm_configs()
|
||||
|
||||
def get_int8_mm_configs(
|
||||
self, device_type: Optional[str] = "cuda"
|
||||
) -> partial[Generator[TritonConfig, None, None]]:
|
||||
mm_heuristics = self.get_config_heuristics(device_type)
|
||||
return mm_heuristics.get_int8_mm_configs()
|
||||
|
||||
def get_mixed_mm_configs(
|
||||
self, device_type: Optional[str] = "cuda"
|
||||
) -> partial[Generator[TritonConfig, None, None]]:
|
||||
mm_heuristics = self.get_config_heuristics(device_type)
|
||||
return mm_heuristics.get_mixed_mm_configs()
|
||||
|
||||
def get_persistent_mm_configs(
|
||||
self, device_type: Optional[str] = "cuda"
|
||||
) -> partial[Generator[TritonConfig, None, None]]:
|
||||
mm_heuristics = self.get_config_heuristics(device_type)
|
||||
return mm_heuristics.get_persistent_mm_configs()
|
||||
|
||||
def get_scaled_mm_configs(
|
||||
self, device_type: Optional[str] = "cuda"
|
||||
) -> partial[Generator[TritonConfig, None, None]]:
|
||||
mm_heuristics = self.get_config_heuristics(device_type)
|
||||
return mm_heuristics.get_scaled_mm_configs()
|
||||
|
||||
def get_scaled_persistent_mm_configs(
|
||||
self, device_type: Optional[str] = "cuda"
|
||||
) -> partial[Generator[TritonConfig, None, None]]:
|
||||
mm_heuristics = self.get_config_heuristics(device_type)
|
||||
return mm_heuristics.get_scaled_persistent_mm_configs()
|
||||
|
||||
def get_mm_plus_mm_configs(
|
||||
self, device_type: Optional[str] = "cuda"
|
||||
) -> partial[Generator[TritonConfig, None, None]]:
|
||||
mm_heuristics = self.get_config_heuristics(device_type)
|
||||
return mm_heuristics.get_mm_plus_mm_configs()
|
||||
|
||||
# Conv configs
|
||||
def get_conv_configs(
|
||||
self, device_type: Optional[str] = "cuda"
|
||||
) -> partial[Generator[TritonConfig, None, None]]:
|
||||
conv_heuristics = self.get_config_heuristics(device_type)
|
||||
return conv_heuristics.get_conv_configs()
|
||||
|
||||
def triton_kernel_kwargs(
|
||||
self,
|
||||
kernel_cls: type[TritonKernel],
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ from .mm_common import (
|
|||
_is_static_problem,
|
||||
addmm_epilogue,
|
||||
mm_args,
|
||||
mm_config_kwargs,
|
||||
mm_configs,
|
||||
mm_options,
|
||||
)
|
||||
|
||||
|
|
@ -43,6 +43,12 @@ def _is_large_block_for_cpu(m, n, k):
|
|||
return m * n > 2**12
|
||||
|
||||
|
||||
def bmm_configs(m, n, k, *, device_type):
|
||||
if device_type == "cpu":
|
||||
return mm_configs(m, n, k, scale=0.5, exclude=_is_large_block_for_cpu)
|
||||
return mm_configs(m, n, k)
|
||||
|
||||
|
||||
bmm_template = TritonTemplate(
|
||||
name="bmm",
|
||||
grid=bmm_grid,
|
||||
|
|
@ -163,14 +169,8 @@ def tuned_bmm(mat1, mat2, *, layout=None):
|
|||
|
||||
# options to tune from
|
||||
choices = [aten_bmm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else []
|
||||
|
||||
device_type = ir.get_device_type(mat1)
|
||||
bmm_configs = V.choices.get_base_mm_configs(device_type)
|
||||
|
||||
if use_triton_template(layout):
|
||||
for config in bmm_configs(
|
||||
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
|
||||
):
|
||||
for config in bmm_configs(m, n, k, device_type=ir.get_device_type(mat1)):
|
||||
bmm_template.maybe_append_choice(
|
||||
choices,
|
||||
input_nodes=(mat1, mat2),
|
||||
|
|
@ -212,14 +212,8 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
|
|||
if use_aten_gemm_kernels()
|
||||
else []
|
||||
)
|
||||
|
||||
device_type = ir.get_device_type(mat1)
|
||||
bmm_configs = V.choices.get_base_mm_configs(device_type)
|
||||
|
||||
if use_triton_template(layout):
|
||||
for config in bmm_configs(
|
||||
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
|
||||
):
|
||||
for config in bmm_configs(m, n, k, device_type=ir.get_device_type(mat1)):
|
||||
bmm_template.maybe_append_choice(
|
||||
choices,
|
||||
input_nodes=(inp, mat1, mat2),
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Optional, TYPE_CHECKING, TypedDict
|
||||
from typing import cast, Optional, TYPE_CHECKING, TypedDict
|
||||
|
||||
import torch
|
||||
from torch._inductor.codegen.rocm.ck_conv_template import CKGroupedConvFwdTemplate
|
||||
|
|
@ -29,7 +29,7 @@ from ..utils import (
|
|||
use_triton_template,
|
||||
)
|
||||
from ..virtualized import V
|
||||
from .mm_common import mm_config_kwargs
|
||||
from .mm_common import build_rocm_gemm_configs, filtered_configs
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -59,6 +59,31 @@ def conv3d_grid(n, c, d, h, w, meta):
|
|||
)
|
||||
|
||||
|
||||
# List of dictionaries to store the kernel configs. Configs that evaluate to true
|
||||
# will be utilised on the target platform
|
||||
kernel_configs = [
|
||||
# "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps"
|
||||
{"config": (64, 256, 16, 2, 4), "cond": True},
|
||||
{"config": (256, 64, 16, 2, 4), "cond": True},
|
||||
{"config": (1024, 16, 16, 1, 8), "cond": True},
|
||||
{"config": (128, 128, 32, 2, 8), "cond": True},
|
||||
{"config": (64, 64, 32, 2, 4), "cond": True},
|
||||
{"config": (64, 256, 32, 2, 8), "cond": True},
|
||||
{"config": (256, 64, 32, 2, 8), "cond": True},
|
||||
]
|
||||
|
||||
# Create filtered list of configs based on conv
|
||||
platform_configs = tuple(
|
||||
cast(tuple[int, int, int, int, int], config["config"])
|
||||
for config in kernel_configs
|
||||
if config["cond"]
|
||||
)
|
||||
|
||||
# On ROCm convert num_stages to 1 as pipelining provides no benefit
|
||||
if torch.version.hip and torch.cuda.is_available():
|
||||
platform_configs = build_rocm_gemm_configs(platform_configs)
|
||||
|
||||
|
||||
def _is_large_block_for_cpu(m, n, k):
|
||||
# Thresholds are experimentally determined to reduce Triton CPU compile times
|
||||
if m > 256 or n > 256 or k > 256:
|
||||
|
|
@ -66,6 +91,19 @@ def _is_large_block_for_cpu(m, n, k):
|
|||
return m * n * k > 2**17
|
||||
|
||||
|
||||
def conv_configs(m, n, k, *, device_type, **kwargs):
|
||||
if device_type == "cpu":
|
||||
return filtered_configs(
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
configs=platform_configs,
|
||||
scale=0.5,
|
||||
exclude=_is_large_block_for_cpu,
|
||||
)
|
||||
return filtered_configs(m, n, k, configs=platform_configs)
|
||||
|
||||
|
||||
LOOP_BODY_2D = """
|
||||
idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H
|
||||
idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W
|
||||
|
|
@ -457,8 +495,6 @@ def convolution(
|
|||
"groups": groups,
|
||||
}
|
||||
|
||||
device_type = ir.get_device_type(x)
|
||||
|
||||
if len(x.get_size()) == len(weight.get_size()) - 1:
|
||||
# add batch dimension to simplify rest of function
|
||||
return L[aten.squeeze](
|
||||
|
|
@ -473,7 +509,11 @@ def convolution(
|
|||
# Always convert conv1D to 2D for Intel GPU.
|
||||
# Only conv2D can be converted to channel last layout,
|
||||
# which have much better performance.
|
||||
if len(x.get_size()) == 3 and len(kernel_shape) == 1 and device_type == "xpu":
|
||||
if (
|
||||
len(x.get_size()) == 3
|
||||
and len(kernel_shape) == 1
|
||||
and ir.get_device_type(x) == "xpu"
|
||||
):
|
||||
kwargs.update(
|
||||
{
|
||||
"stride": (1,) + stride,
|
||||
|
|
@ -522,7 +562,7 @@ def convolution(
|
|||
):
|
||||
return convert_1x1_conv_to_mm(x, weight, bias)
|
||||
|
||||
if bias is not None and device_type != "cpu":
|
||||
if bias is not None and ir.get_device_type(x) != "cpu":
|
||||
# peel off the bias, cudnn is slower with it
|
||||
result = convolution(x, weight, None, **kwargs)
|
||||
return L[aten.add](
|
||||
|
|
@ -597,13 +637,11 @@ def convolution(
|
|||
):
|
||||
choices.append(aten_conv1x1_via_mm.bind(args, layout))
|
||||
|
||||
conv_configs = V.choices.get_conv_configs(device_type)
|
||||
|
||||
for cfg in conv_configs(
|
||||
sympy_product([x.get_size()[0], *x.get_size()[2:]]),
|
||||
out_chan,
|
||||
in_chan,
|
||||
**mm_config_kwargs(device_type, _is_large_block_for_cpu),
|
||||
device_type=ir.get_device_type(x),
|
||||
):
|
||||
if ndim == 2:
|
||||
conv2d_template.maybe_append_choice(
|
||||
|
|
|
|||
|
|
@ -43,12 +43,17 @@ from ..utils import (
|
|||
from .mm_common import (
|
||||
_is_static_problem,
|
||||
addmm_epilogue,
|
||||
extra_mm_configs,
|
||||
int8_mm_configs,
|
||||
mixed_mm_configs,
|
||||
mm_args,
|
||||
mm_config_kwargs,
|
||||
mm_configs,
|
||||
mm_grid,
|
||||
mm_options,
|
||||
persistent_mm_configs,
|
||||
persistent_mm_grid,
|
||||
persistent_mm_options,
|
||||
triton_config,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -335,6 +340,15 @@ def _is_large_block_for_cpu(m, n, k):
|
|||
return m * n > 2**13
|
||||
|
||||
|
||||
def mm_config_kwargs(device):
|
||||
if device == "cpu":
|
||||
return {
|
||||
"scale": 0.5,
|
||||
"exclude": _is_large_block_for_cpu,
|
||||
}
|
||||
return {}
|
||||
|
||||
|
||||
def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1):
|
||||
"""
|
||||
Giving torch.addmm a 1D tensor calls a different (faster) cublasLt
|
||||
|
|
@ -352,7 +366,6 @@ aten_bias_addmm = ExternKernelChoice(bias_addmm, None)
|
|||
@register_lowering(aten.mm, type_promotion_kind=None)
|
||||
def tuned_mm(mat1, mat2, *, layout=None):
|
||||
m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
|
||||
device_type = ir.get_device_type(mat1)
|
||||
name = "mm"
|
||||
|
||||
aten_layout = layout
|
||||
|
|
@ -366,15 +379,8 @@ def tuned_mm(mat1, mat2, *, layout=None):
|
|||
[aten_mm.bind((mat1, mat2), aten_layout)] if use_aten_gemm_kernels() else []
|
||||
)
|
||||
static_shape, is_nonzero = _is_static_problem(layout)
|
||||
|
||||
mm_configs = V.choices.get_base_mm_configs(device_type)
|
||||
persistent_mm_configs = V.choices.get_persistent_mm_configs(device_type)
|
||||
extra_mm_configs = V.choices.get_extra_mm_configs(device_type)
|
||||
|
||||
if is_nonzero and use_triton_template(layout):
|
||||
for config in mm_configs(
|
||||
m, n, k, *mm_config_kwargs(device_type, _is_large_block_for_cpu)
|
||||
):
|
||||
for config in mm_configs(m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))):
|
||||
mm_template.maybe_append_choice(
|
||||
choices,
|
||||
input_nodes=(mat1, mat2),
|
||||
|
|
@ -383,7 +389,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
|
|||
)
|
||||
if use_triton_tma_template(mat1, mat2):
|
||||
for config in persistent_mm_configs(
|
||||
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
|
||||
m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))
|
||||
):
|
||||
persistent_tma_mm_template.maybe_append_choice(
|
||||
choices,
|
||||
|
|
@ -422,7 +428,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
|
|||
always_included.append("extern_mm")
|
||||
num_choices_before_extra_configs = len(choices)
|
||||
for config in extra_mm_configs(
|
||||
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
|
||||
m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))
|
||||
):
|
||||
mm_template.maybe_append_choice(
|
||||
choices,
|
||||
|
|
@ -482,7 +488,6 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
|
|||
m, n, k, layout, mat1, mat2 = mm_args(
|
||||
mat1, mat2, layout=layout, out_dtype=torch.int32
|
||||
)
|
||||
device_type = ir.get_device_type(mat1)
|
||||
static_shape, is_nonzero = _is_static_problem(layout)
|
||||
use_cutlass = static_shape and is_nonzero and use_cutlass_template(layout, m, n, k)
|
||||
|
||||
|
|
@ -498,12 +503,9 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
|
|||
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
|
||||
choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True
|
||||
)
|
||||
|
||||
int8_mm_configs = V.choices.get_int8_mm_configs(device_type)
|
||||
|
||||
if is_nonzero and use_triton_template(layout, enable_int32=True):
|
||||
for config in int8_mm_configs(
|
||||
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
|
||||
m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))
|
||||
):
|
||||
mm_template.maybe_append_choice(
|
||||
choices,
|
||||
|
|
@ -530,7 +532,6 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
|
|||
@register_lowering(aten.addmm, type_promotion_kind=None)
|
||||
def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
|
||||
ordered_kwargs_for_cpp_kernel = ("beta", "alpha")
|
||||
device_type = ir.get_device_type(mat1)
|
||||
m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout)
|
||||
static_shape, is_nonzero = _is_static_problem(layout)
|
||||
if (not is_nonzero) or (not use_max_autotune()):
|
||||
|
|
@ -583,13 +584,8 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
|
|||
),
|
||||
)
|
||||
|
||||
mm_configs = V.choices.get_base_mm_configs(device_type)
|
||||
persistent_mm_configs = V.choices.get_persistent_mm_configs(device_type)
|
||||
|
||||
if is_nonzero and use_triton_template(layout):
|
||||
for config in mm_configs(
|
||||
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
|
||||
):
|
||||
for config in mm_configs(m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))):
|
||||
mm_template.maybe_append_choice(
|
||||
choices,
|
||||
input_nodes=(inp_expanded, mat1, mat2),
|
||||
|
|
@ -601,7 +597,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
|
|||
|
||||
if use_triton_tma_template(mat1, mat2):
|
||||
for config in persistent_mm_configs(
|
||||
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
|
||||
m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))
|
||||
):
|
||||
persistent_tma_mm_template.maybe_append_choice(
|
||||
choices,
|
||||
|
|
@ -764,7 +760,7 @@ def dims_are_int(dims):
|
|||
return all(isinstance(dim, int) for dim in dims)
|
||||
|
||||
|
||||
def try_heuristic(m, n, k, choices, mat1, mat2, mat2_dtype, layout, mm_heuristic):
|
||||
def try_heuristic(m, n, k, choices, mat1, mat2, mat2_dtype, layout):
|
||||
m, n, k = get_size_hints(mat1, mat2, m, n, k)
|
||||
if not dims_are_int([m, n, k]):
|
||||
return None
|
||||
|
|
@ -779,10 +775,35 @@ def try_heuristic(m, n, k, choices, mat1, mat2, mat2_dtype, layout, mm_heuristic
|
|||
not torch.cuda.get_device_capability() >= (8, 0)
|
||||
) or get_gpu_shared_memory() != 166912:
|
||||
return None
|
||||
elif m == 1 and (n % 16 != 0 or k % 16 != 0):
|
||||
|
||||
if m == 1 and (n % 16 != 0 or k % 16 != 0):
|
||||
return None
|
||||
else:
|
||||
return mm_heuristic.generate_mixed_mm_config()
|
||||
|
||||
if m <= 16 and n >= 4096 and k >= 4096:
|
||||
return triton_config(
|
||||
BLOCK_M=16,
|
||||
BLOCK_N=64,
|
||||
BLOCK_K=128,
|
||||
num_stages=5,
|
||||
num_warps=4,
|
||||
)
|
||||
elif m > 16 and m <= 32 and n >= 4096 and k >= 4096:
|
||||
return triton_config(
|
||||
BLOCK_M=32,
|
||||
BLOCK_N=32,
|
||||
BLOCK_K=128,
|
||||
num_stages=5,
|
||||
num_warps=4,
|
||||
)
|
||||
elif m > 32 and m <= 64 and n >= 4096 and k >= 4096:
|
||||
return triton_config(
|
||||
BLOCK_M=64,
|
||||
BLOCK_N=32,
|
||||
BLOCK_K=128,
|
||||
num_stages=5,
|
||||
num_warps=4,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def mm_autoheuristic(
|
||||
|
|
@ -879,7 +900,6 @@ def get_size_hints_strides(mat1, mat2):
|
|||
|
||||
def tuned_mixed_mm(mat1, mat2, mat2_dtype):
|
||||
m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=None)
|
||||
device_type = ir.get_device_type(mat1)
|
||||
static_shape, is_nonzero = _is_static_problem(layout)
|
||||
|
||||
fallback = aten_fallback_mixed_mm.bind((mat1, mat2), layout)
|
||||
|
|
@ -905,15 +925,10 @@ def tuned_mixed_mm(mat1, mat2, mat2_dtype):
|
|||
choices = []
|
||||
|
||||
if not skip_triton:
|
||||
mm_heuristic = V.choices.get_config_heuristics(device_type)
|
||||
mixed_mm_configs = V.choices.get_mixed_mm_configs(device_type)
|
||||
|
||||
b_prologue_cast_type = f"tl.{mat2_dtype}".replace("torch.", "")
|
||||
if static_shape and inductor_config.mixed_mm_choice == "heuristic":
|
||||
choices = []
|
||||
config = try_heuristic(
|
||||
m, n, k, choices, mat1, mat2, mat2_dtype, layout, mm_heuristic
|
||||
)
|
||||
config = try_heuristic(m, n, k, choices, mat1, mat2, mat2_dtype, layout)
|
||||
if config is not None:
|
||||
mm_template.maybe_append_choice(
|
||||
choices,
|
||||
|
|
@ -924,13 +939,12 @@ def tuned_mixed_mm(mat1, mat2, mat2_dtype):
|
|||
choices.append(fallback)
|
||||
|
||||
has_int8_tensor = _is_int8_mat(mat1) or _is_int8_mat(mat2)
|
||||
|
||||
for config in mixed_mm_configs(
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
has_int8_tensor=has_int8_tensor,
|
||||
**mm_config_kwargs(device_type, _is_large_block_for_cpu),
|
||||
**mm_config_kwargs(ir.get_device_type(mat1)),
|
||||
):
|
||||
mm_template.maybe_append_choice(
|
||||
choices,
|
||||
|
|
@ -987,15 +1001,13 @@ def tuned_fused_int_mm_mul(mat1, mat2, mat3, out_dtype, *, layout=None):
|
|||
m, n, k, layout, mat1, mat2, mat3 = mm_args(
|
||||
mat1, mat2, mat3, layout=layout, out_dtype=out_dtype
|
||||
)
|
||||
device_type = ir.get_device_type(mat1)
|
||||
|
||||
def mul_epilogue(v1, v2):
|
||||
return V.ops.mul(v1, v2)
|
||||
|
||||
choices: list[dict[Any, Any]] = []
|
||||
int8_mm_configs = V.choices.get_int8_mm_configs(device_type)
|
||||
for config in int8_mm_configs(
|
||||
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
|
||||
m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))
|
||||
):
|
||||
mm_template.maybe_append_choice(
|
||||
choices,
|
||||
|
|
|
|||
|
|
@ -1,22 +1,437 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import functools
|
||||
import itertools
|
||||
import logging
|
||||
from typing import Any
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
import sympy
|
||||
|
||||
import torch
|
||||
from torch._inductor.select_algorithm import realize_inputs
|
||||
from torch._inductor.virtualized import V
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
from .. import config as inductor_config
|
||||
from ..codegen.wrapper import PythonWrapperCodegen
|
||||
from ..ir import Layout
|
||||
from ..utils import ceildiv as cdiv, get_num_sms, TMA_DESCRIPTOR_SIZE
|
||||
from ..runtime.runtime_utils import next_power_of_2
|
||||
from ..utils import (
|
||||
ceildiv as cdiv,
|
||||
get_backend_num_stages,
|
||||
get_num_sms,
|
||||
TMA_DESCRIPTOR_SIZE,
|
||||
)
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def triton_config(num_stages, num_warps, **kwargs):
|
||||
from triton import Config # type: ignore[attr-defined]
|
||||
|
||||
return Config(kwargs, num_stages=num_stages, num_warps=num_warps)
|
||||
|
||||
|
||||
def build_rocm_gemm_configs(configs):
|
||||
rocm_num_stages = get_backend_num_stages()
|
||||
return tuple((c[0], c[1], c[2], rocm_num_stages, c[4]) for c in configs)
|
||||
|
||||
|
||||
def filtered_configs(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
configs: Sequence[tuple[int, int, int, int, int]],
|
||||
has_int8_tensor=False,
|
||||
scale=1,
|
||||
exclude=lambda m, n, k: False,
|
||||
):
|
||||
"""
|
||||
Heuristic to shrink configs when they are bigger than the input size
|
||||
|
||||
:param scale: scale factor applied to the config values
|
||||
:param exclude: whether a given config should be excluded
|
||||
"""
|
||||
from torch._inductor import config
|
||||
|
||||
max_mm_configs = config.test_configs.max_mm_configs
|
||||
|
||||
min_block_size = 16
|
||||
# block_k=16 seems to be causing issues
|
||||
# see: https://github.com/triton-lang/triton/issues/2156#issuecomment-1695897424
|
||||
min_block_size_k = 32 if has_int8_tensor else 16
|
||||
m = max(
|
||||
next_power_of_2(
|
||||
V.graph.sizevars.size_hint(
|
||||
m, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type]
|
||||
)
|
||||
),
|
||||
min_block_size,
|
||||
)
|
||||
n = max(
|
||||
next_power_of_2(
|
||||
V.graph.sizevars.size_hint(
|
||||
n, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type]
|
||||
)
|
||||
),
|
||||
min_block_size,
|
||||
)
|
||||
k = max(
|
||||
next_power_of_2(
|
||||
V.graph.sizevars.size_hint(
|
||||
k, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type]
|
||||
)
|
||||
),
|
||||
min_block_size_k,
|
||||
)
|
||||
used = OrderedSet[tuple[int, int, int, int, int, int]]()
|
||||
for block_m, block_n, block_k, num_stages, num_warps in configs:
|
||||
# shrink configs for small sizes
|
||||
block_m = max(min(int(block_m * scale), m), min_block_size)
|
||||
block_n = max(min(int(block_n * scale), n), min_block_size)
|
||||
block_k = max(min(int(block_k * scale), k), min_block_size_k)
|
||||
|
||||
if exclude(block_m, block_n, block_k):
|
||||
continue
|
||||
|
||||
# each warp computes 16x16 tile = 256
|
||||
num_warps = min(num_warps, block_m * block_n // 256)
|
||||
if torch.version.hip:
|
||||
for matrix_instr_nonkdim in [0, 16]:
|
||||
if matrix_instr_nonkdim != 0 and (
|
||||
block_m % matrix_instr_nonkdim != 0
|
||||
or block_n % matrix_instr_nonkdim != 0
|
||||
):
|
||||
# block_m and block_n must be a multiple of matrix_instr_nonkdim
|
||||
continue
|
||||
if (
|
||||
block_m,
|
||||
block_n,
|
||||
block_k,
|
||||
num_stages,
|
||||
num_warps,
|
||||
matrix_instr_nonkdim,
|
||||
) not in used and (
|
||||
max_mm_configs is None or len(used) < max_mm_configs
|
||||
):
|
||||
used.add(
|
||||
(
|
||||
block_m,
|
||||
block_n,
|
||||
block_k,
|
||||
num_stages,
|
||||
num_warps,
|
||||
matrix_instr_nonkdim,
|
||||
)
|
||||
)
|
||||
yield triton_config(
|
||||
BLOCK_M=block_m,
|
||||
BLOCK_N=block_n,
|
||||
BLOCK_K=block_k,
|
||||
num_stages=num_stages,
|
||||
num_warps=num_warps,
|
||||
matrix_instr_nonkdim=matrix_instr_nonkdim,
|
||||
)
|
||||
else:
|
||||
if (block_m, block_n, block_k, num_stages, num_warps, 0) not in used and (
|
||||
max_mm_configs is None or len(used) < max_mm_configs
|
||||
):
|
||||
used.add((block_m, block_n, block_k, num_stages, num_warps, 0))
|
||||
yield triton_config(
|
||||
BLOCK_M=block_m,
|
||||
BLOCK_N=block_n,
|
||||
BLOCK_K=block_k,
|
||||
num_stages=num_stages,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
|
||||
|
||||
# List of dictionaries to store the kernel configs. Configs that evaluate to true
|
||||
# will be utilised on the target platform. The configs are as follows:
|
||||
# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
|
||||
mm_kernel_configs = (
|
||||
[
|
||||
{"config": (32, 32, 16, 1, 2), "cond": True},
|
||||
{"config": (32, 32, 128, 2, 4), "cond": True},
|
||||
{"config": (32, 64, 32, 5, 8), "cond": True},
|
||||
{"config": (64, 32, 32, 5, 8), "cond": True},
|
||||
{"config": (64, 32, 128, 5, 4), "cond": True},
|
||||
{"config": (64, 64, 16, 2, 4), "cond": True},
|
||||
{"config": (64, 64, 32, 2, 4), "cond": True},
|
||||
{"config": (64, 64, 64, 3, 8), "cond": True},
|
||||
{"config": (64, 64, 128, 5, 4), "cond": True},
|
||||
{"config": (64, 128, 32, 3, 4), "cond": True},
|
||||
{"config": (64, 128, 32, 4, 8), "cond": True},
|
||||
{"config": (64, 128, 64, 3, 4), "cond": True},
|
||||
{"config": (64, 128, 128, 4, 4), "cond": True},
|
||||
{"config": (128, 64, 32, 3, 4), "cond": True},
|
||||
{"config": (128, 64, 32, 4, 8), "cond": True},
|
||||
{"config": (128, 128, 32, 2, 8), "cond": True},
|
||||
{"config": (128, 128, 32, 3, 4), "cond": True},
|
||||
{"config": (128, 128, 64, 3, 4), "cond": True},
|
||||
{"config": (128, 128, 64, 5, 8), "cond": True},
|
||||
]
|
||||
if inductor_config.max_autotune_gemm_search_space != "EXHAUSTIVE"
|
||||
else [
|
||||
{"config": (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps), "cond": True}
|
||||
for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product(
|
||||
[16, 32, 64, 128, 256], repeat=3
|
||||
)
|
||||
for num_stages in [1, 2, 3, 4, 5]
|
||||
for num_warps in [2, 4, 8]
|
||||
]
|
||||
)
|
||||
|
||||
# these are only used in tuned_mm when AutoHeuristic is enabled
|
||||
# the idea is that when AutoHeuristic collects data to learn a heuristic, more configs are autotuned
|
||||
# when the learned heuristic is used, the learned heuristic reduces the number of configs down to 10
|
||||
# which saves compilation time (since less configs are autotuned) and potentially increase performance
|
||||
# because the learned heuristic might predict a config that is not part mm_configs
|
||||
extra_mm_kernel_configs = [
|
||||
{"config": (16, 32, 16, 3, 2), "cond": True},
|
||||
{"config": (16, 32, 32, 4, 2), "cond": True},
|
||||
{"config": (16, 32, 32, 5, 2), "cond": True},
|
||||
{"config": (64, 64, 128, 3, 4), "cond": True},
|
||||
{"config": (128, 64, 32, 2, 2), "cond": True},
|
||||
{"config": (128, 64, 64, 3, 8), "cond": True},
|
||||
{"config": (128, 64, 128, 4, 8), "cond": True},
|
||||
{"config": (128, 128, 32, 4, 4), "cond": True},
|
||||
{"config": (128, 128, 64, 3, 8), "cond": True},
|
||||
{"config": (128, 128, 64, 5, 4), "cond": True},
|
||||
]
|
||||
|
||||
int8_mm_kernel_configs = [
|
||||
{"config": (64, 64, 32, 2, 4), "cond": True},
|
||||
{"config": (64, 128, 32, 3, 4), "cond": True},
|
||||
{"config": (128, 64, 32, 3, 4), "cond": True},
|
||||
{"config": (64, 128, 32, 4, 8), "cond": True},
|
||||
{"config": (128, 64, 32, 4, 8), "cond": True},
|
||||
{"config": (64, 32, 32, 5, 8), "cond": True},
|
||||
{"config": (32, 64, 32, 5, 8), "cond": True},
|
||||
{"config": (128, 128, 32, 2, 8), "cond": True},
|
||||
{"config": (64, 64, 64, 3, 8), "cond": True},
|
||||
# {"config": (32, 32, 128, 2, 4), "cond": True},
|
||||
# {"config": (64, 64, 16, 2, 4), "cond": True},
|
||||
# {"config": (32, 32, 16, 1, 2), "cond": True},
|
||||
{"config": (128, 256, 128, 3, 8), "cond": True},
|
||||
{"config": (256, 128, 128, 3, 8), "cond": True},
|
||||
]
|
||||
|
||||
# Mixed precision kernel configs for small sizes of m for mm's like (16, 8192) x (8192, 8192).
|
||||
mixed_mm_kernel_configs_small_m = [
|
||||
{"config": (16, 128, 256, 3, 4), "cond": True},
|
||||
{"config": (16, 128, 256, 5, 8), "cond": True},
|
||||
]
|
||||
|
||||
mixed_mm_kernel_configs = (
|
||||
mm_kernel_configs + mixed_mm_kernel_configs_small_m
|
||||
if inductor_config.max_autotune_gemm_search_space != "EXHAUSTIVE"
|
||||
else mm_kernel_configs
|
||||
)
|
||||
|
||||
persistent_mm_kernel_configs = [
|
||||
{"config": (128, 256, 64, 3, 8), "cond": True},
|
||||
{"config": (128, 128, 64, 3, 8), "cond": True},
|
||||
{"config": (128, 128, 128, 3, 8), "cond": True},
|
||||
{"config": (128, 128, 128, 3, 4), "cond": True},
|
||||
{"config": (128, 128, 64, 4, 8), "cond": True},
|
||||
]
|
||||
|
||||
scaled_mm_kernel_configs = [
|
||||
{"config": (128, 256, 32, 3, 8), "cond": True},
|
||||
{"config": (256, 128, 32, 3, 8), "cond": True},
|
||||
{"config": (256, 64, 32, 4, 4), "cond": True},
|
||||
{"config": (64, 256, 32, 4, 4), "cond": True},
|
||||
{"config": (128, 128, 32, 4, 4), "cond": True},
|
||||
{"config": (128, 64, 32, 4, 4), "cond": True},
|
||||
{"config": (64, 128, 32, 4, 4), "cond": True},
|
||||
{"config": (128, 32, 32, 4, 4), "cond": True},
|
||||
{"config": (64, 32, 32, 5, 2), "cond": True},
|
||||
{"config": (256, 128, 128, 3, 8), "cond": True},
|
||||
{"config": (256, 64, 128, 4, 4), "cond": True},
|
||||
{"config": (64, 256, 128, 4, 4), "cond": True},
|
||||
{"config": (128, 128, 128, 4, 4), "cond": True},
|
||||
{"config": (128, 64, 64, 4, 4), "cond": True},
|
||||
{"config": (64, 128, 64, 4, 4), "cond": True},
|
||||
{"config": (128, 32, 64, 4, 4), "cond": True},
|
||||
{"config": (64, 32, 64, 5, 2), "cond": True},
|
||||
{"config": (16, 32, 32, 2, 2), "cond": True},
|
||||
{"config": (16, 64, 32, 2, 2), "cond": True},
|
||||
{"config": (16, 128, 32, 2, 4), "cond": True},
|
||||
{"config": (16, 256, 32, 2, 4), "cond": True},
|
||||
{"config": (16, 32, 64, 2, 2), "cond": True},
|
||||
{"config": (16, 64, 64, 2, 2), "cond": True},
|
||||
{"config": (16, 128, 64, 2, 4), "cond": True},
|
||||
{"config": (16, 256, 64, 2, 4), "cond": True},
|
||||
{"config": (32, 32, 32, 2, 2), "cond": True},
|
||||
{"config": (32, 64, 32, 2, 2), "cond": True},
|
||||
{"config": (32, 128, 32, 2, 4), "cond": True},
|
||||
{"config": (32, 256, 32, 2, 4), "cond": True},
|
||||
{"config": (32, 32, 64, 2, 2), "cond": True},
|
||||
{"config": (32, 64, 64, 2, 2), "cond": True},
|
||||
{"config": (32, 128, 64, 2, 4), "cond": True},
|
||||
{"config": (32, 256, 64, 2, 4), "cond": True},
|
||||
{"config": (16, 32, 32, 3, 2), "cond": True},
|
||||
{"config": (16, 64, 32, 3, 2), "cond": True},
|
||||
{"config": (16, 128, 32, 3, 4), "cond": True},
|
||||
{"config": (16, 256, 32, 3, 4), "cond": True},
|
||||
{"config": (16, 32, 64, 3, 2), "cond": True},
|
||||
{"config": (16, 64, 64, 3, 2), "cond": True},
|
||||
{"config": (16, 128, 64, 3, 4), "cond": True},
|
||||
{"config": (16, 256, 64, 3, 4), "cond": True},
|
||||
{"config": (32, 32, 32, 3, 2), "cond": True},
|
||||
{"config": (32, 64, 32, 3, 2), "cond": True},
|
||||
{"config": (32, 128, 32, 3, 4), "cond": True},
|
||||
{"config": (32, 256, 32, 3, 4), "cond": True},
|
||||
{"config": (32, 32, 64, 3, 2), "cond": True},
|
||||
{"config": (32, 64, 64, 3, 2), "cond": True},
|
||||
{"config": (32, 128, 64, 3, 4), "cond": True},
|
||||
{"config": (32, 256, 64, 3, 4), "cond": True},
|
||||
{"config": (16, 32, 32, 4, 2), "cond": True},
|
||||
{"config": (16, 64, 32, 4, 2), "cond": True},
|
||||
{"config": (16, 128, 32, 4, 4), "cond": True},
|
||||
{"config": (16, 256, 32, 4, 4), "cond": True},
|
||||
{"config": (16, 32, 64, 4, 2), "cond": True},
|
||||
{"config": (16, 64, 64, 4, 2), "cond": True},
|
||||
{"config": (16, 128, 64, 4, 4), "cond": True},
|
||||
{"config": (16, 256, 64, 4, 4), "cond": True},
|
||||
{"config": (32, 32, 32, 4, 2), "cond": True},
|
||||
{"config": (32, 64, 32, 4, 2), "cond": True},
|
||||
{"config": (32, 128, 32, 4, 4), "cond": True},
|
||||
{"config": (32, 256, 32, 4, 4), "cond": True},
|
||||
{"config": (32, 32, 64, 4, 2), "cond": True},
|
||||
{"config": (32, 64, 64, 4, 2), "cond": True},
|
||||
{"config": (32, 128, 64, 4, 4), "cond": True},
|
||||
{"config": (32, 256, 64, 4, 4), "cond": True},
|
||||
{"config": (16, 32, 32, 5, 2), "cond": True},
|
||||
{"config": (16, 64, 32, 5, 2), "cond": True},
|
||||
{"config": (16, 128, 32, 5, 4), "cond": True},
|
||||
{"config": (16, 256, 32, 5, 4), "cond": True},
|
||||
{"config": (16, 32, 64, 5, 2), "cond": True},
|
||||
{"config": (16, 64, 64, 5, 2), "cond": True},
|
||||
{"config": (16, 128, 64, 5, 4), "cond": True},
|
||||
{"config": (16, 256, 64, 5, 4), "cond": True},
|
||||
{"config": (32, 32, 32, 5, 2), "cond": True},
|
||||
{"config": (32, 64, 32, 5, 2), "cond": True},
|
||||
{"config": (32, 128, 32, 5, 4), "cond": True},
|
||||
{"config": (32, 256, 32, 5, 4), "cond": True},
|
||||
{"config": (32, 32, 64, 5, 2), "cond": True},
|
||||
{"config": (32, 64, 64, 5, 2), "cond": True},
|
||||
{"config": (32, 128, 64, 5, 4), "cond": True},
|
||||
{"config": (32, 256, 64, 5, 4), "cond": True},
|
||||
{"config": (16, 32, 32, 6, 2), "cond": True},
|
||||
{"config": (16, 64, 32, 6, 2), "cond": True},
|
||||
{"config": (16, 128, 32, 6, 4), "cond": True},
|
||||
{"config": (16, 256, 32, 6, 4), "cond": True},
|
||||
{"config": (16, 32, 64, 6, 2), "cond": True},
|
||||
{"config": (16, 64, 64, 6, 2), "cond": True},
|
||||
{"config": (16, 128, 64, 6, 4), "cond": True},
|
||||
{"config": (16, 256, 64, 6, 4), "cond": True},
|
||||
{"config": (32, 32, 32, 6, 2), "cond": True},
|
||||
{"config": (32, 64, 32, 6, 2), "cond": True},
|
||||
{"config": (32, 128, 32, 6, 4), "cond": True},
|
||||
{"config": (32, 256, 32, 6, 4), "cond": True},
|
||||
{"config": (32, 32, 64, 6, 2), "cond": True},
|
||||
{"config": (32, 64, 64, 6, 2), "cond": True},
|
||||
{"config": (32, 128, 64, 6, 4), "cond": True},
|
||||
{"config": (32, 256, 64, 6, 4), "cond": True},
|
||||
]
|
||||
|
||||
scaled_persistent_mm_kernel_configs = [
|
||||
{"config": (128, 128, 64, 3, 8), "cond": True},
|
||||
{"config": (128, 128, 128, 3, 8), "cond": True},
|
||||
{"config": (128, 128, 128, 4, 8), "cond": True},
|
||||
{"config": (128, 128, 128, 4, 4), "cond": True},
|
||||
{"config": (128, 128, 128, 3, 4), "cond": True},
|
||||
{"config": (128, 128, 128, 5, 4), "cond": True},
|
||||
{"config": (128, 128, 128, 5, 8), "cond": True},
|
||||
{"config": (128, 128, 128, 6, 8), "cond": True},
|
||||
{"config": (128, 128, 64, 4, 8), "cond": True},
|
||||
]
|
||||
|
||||
|
||||
# Create filtered list of configs based on cond evaluation
|
||||
mm_platform_configs = tuple(
|
||||
cast(tuple[int, int, int, int, int], config["config"])
|
||||
for config in mm_kernel_configs
|
||||
if config["cond"]
|
||||
)
|
||||
extra_mm_platform_configs = tuple(
|
||||
cast(tuple[int, int, int, int, int], config["config"])
|
||||
for config in extra_mm_kernel_configs
|
||||
if config["cond"]
|
||||
)
|
||||
int8_platform_configs = tuple(
|
||||
cast(tuple[int, int, int, int, int], config["config"])
|
||||
for config in int8_mm_kernel_configs
|
||||
if config["cond"]
|
||||
)
|
||||
mixed_mm_platform_configs = tuple(
|
||||
cast(tuple[int, int, int, int, int], config["config"])
|
||||
for config in mixed_mm_kernel_configs
|
||||
if config["cond"]
|
||||
)
|
||||
persistent_mm_platform_configs = tuple(
|
||||
cast(tuple[int, int, int, int, int], config["config"])
|
||||
for config in persistent_mm_kernel_configs
|
||||
if config["cond"]
|
||||
)
|
||||
scaled_mm_platform_configs = tuple(
|
||||
cast(tuple[int, int, int, int, int], config["config"])
|
||||
for config in scaled_mm_kernel_configs
|
||||
if config["cond"]
|
||||
)
|
||||
scaled_persistent_mm_platform_configs = tuple(
|
||||
cast(tuple[int, int, int, int, int], config["config"])
|
||||
for config in scaled_persistent_mm_kernel_configs
|
||||
if config["cond"]
|
||||
)
|
||||
|
||||
# On ROCm convert num_stages to improve performance
|
||||
if torch.version.hip and torch.cuda.is_available():
|
||||
mm_platform_configs = build_rocm_gemm_configs(mm_platform_configs)
|
||||
extra_mm_platform_configs = build_rocm_gemm_configs(extra_mm_platform_configs)
|
||||
int8_platform_configs = build_rocm_gemm_configs(int8_platform_configs)
|
||||
mixed_mm_platform_configs = build_rocm_gemm_configs(mixed_mm_platform_configs)
|
||||
scaled_mm_platform_configs = build_rocm_gemm_configs(scaled_mm_platform_configs)
|
||||
|
||||
mm_configs = functools.partial(
|
||||
filtered_configs,
|
||||
configs=mm_platform_configs,
|
||||
)
|
||||
|
||||
extra_mm_configs = functools.partial(
|
||||
filtered_configs,
|
||||
configs=extra_mm_platform_configs,
|
||||
)
|
||||
|
||||
int8_mm_configs = functools.partial(
|
||||
filtered_configs,
|
||||
configs=int8_platform_configs,
|
||||
)
|
||||
|
||||
mixed_mm_configs = functools.partial(
|
||||
filtered_configs,
|
||||
configs=mixed_mm_platform_configs,
|
||||
)
|
||||
|
||||
persistent_mm_configs = functools.partial(
|
||||
filtered_configs,
|
||||
configs=persistent_mm_platform_configs,
|
||||
)
|
||||
|
||||
scaled_mm_configs = functools.partial(
|
||||
filtered_configs,
|
||||
configs=scaled_mm_platform_configs,
|
||||
)
|
||||
|
||||
scaled_persistent_mm_configs = functools.partial(
|
||||
filtered_configs,
|
||||
configs=scaled_persistent_mm_platform_configs,
|
||||
)
|
||||
|
||||
|
||||
def mm_grid(m, n, meta):
|
||||
"""
|
||||
The CUDA grid size for matmul triton templates.
|
||||
|
|
@ -115,15 +530,6 @@ def mm_args(
|
|||
return [m, n, k, layout, mat1, mat2, *others]
|
||||
|
||||
|
||||
def mm_config_kwargs(device, exclude_condition):
|
||||
if device == "cpu":
|
||||
return {
|
||||
"scale": 0.5,
|
||||
"exclude": exclude_condition,
|
||||
}
|
||||
return {}
|
||||
|
||||
|
||||
def addmm_epilogue(dtype, alpha, beta):
|
||||
def epilogue(acc, bias):
|
||||
if alpha != 1:
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import functools
|
||||
|
||||
import torch
|
||||
|
||||
from .. import ir
|
||||
from ..lowering import lowerings
|
||||
from ..select_algorithm import (
|
||||
autotune_select_algorithm,
|
||||
|
|
@ -112,14 +112,101 @@ mm_plus_mm_template = TritonTemplate(
|
|||
)
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def mm_configs():
|
||||
import triton
|
||||
|
||||
# List of dictionaries to store the kernel configs. Configs that evaluate to true
|
||||
# will be utilised on the target platform
|
||||
mm_triton_configs = [
|
||||
{
|
||||
"config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32},
|
||||
"num_stages": 2,
|
||||
"num_warps": 4,
|
||||
"cond": True,
|
||||
},
|
||||
{
|
||||
"config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32},
|
||||
"num_stages": 3,
|
||||
"num_warps": 8,
|
||||
"cond": True,
|
||||
},
|
||||
{
|
||||
"config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32},
|
||||
"num_stages": 4,
|
||||
"num_warps": 16,
|
||||
"cond": True,
|
||||
},
|
||||
{
|
||||
"config": {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32},
|
||||
"num_stages": 4,
|
||||
"num_warps": 8,
|
||||
"cond": True,
|
||||
},
|
||||
{
|
||||
"config": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32},
|
||||
"num_stages": 4,
|
||||
"num_warps": 8,
|
||||
"cond": True,
|
||||
},
|
||||
{
|
||||
"config": {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32},
|
||||
"num_stages": 1,
|
||||
"num_warps": 8,
|
||||
"cond": True,
|
||||
},
|
||||
{
|
||||
"config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64},
|
||||
"num_stages": 1,
|
||||
"num_warps": 8,
|
||||
"cond": True,
|
||||
},
|
||||
{
|
||||
"config": {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 128},
|
||||
"num_stages": 1,
|
||||
"num_warps": 8,
|
||||
"cond": torch.version.hip is None,
|
||||
},
|
||||
{
|
||||
"config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 16},
|
||||
"num_stages": 2,
|
||||
"num_warps": 4,
|
||||
"cond": True,
|
||||
},
|
||||
{
|
||||
"config": {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 16},
|
||||
"num_stages": 1,
|
||||
"num_warps": 2,
|
||||
"cond": True,
|
||||
},
|
||||
]
|
||||
|
||||
# Filter out configs in which cond evaluates to true
|
||||
# On ROCm convert num_stages to 1 as pipelining provides no benefit
|
||||
if torch.version.hip:
|
||||
filtered_configs = [
|
||||
triton.Config(c["config"], num_stages=1, num_warps=c["num_warps"])
|
||||
for c in mm_triton_configs
|
||||
if c["cond"]
|
||||
]
|
||||
else:
|
||||
filtered_configs = [
|
||||
triton.Config(
|
||||
c["config"], num_stages=c["num_stages"], num_warps=c["num_warps"]
|
||||
)
|
||||
for c in mm_triton_configs
|
||||
if c["cond"]
|
||||
]
|
||||
|
||||
return filtered_configs
|
||||
|
||||
|
||||
def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
|
||||
"""
|
||||
Computes mm(mat1, mat2) + mm(mat3, mat4)
|
||||
"""
|
||||
m1, n1, k1, layout1, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
|
||||
m2, n2, _, layout2, mat3, mat4 = mm_args(mat3, mat4, layout=layout)
|
||||
device_type = ir.get_device_type(mat1)
|
||||
|
||||
# Optimization is optional, because we can always just not do the fusion
|
||||
if (
|
||||
m1 * n1 == 0
|
||||
|
|
@ -144,9 +231,6 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
|
|||
if use_aten_gemm_kernels()
|
||||
else []
|
||||
)
|
||||
|
||||
mm_configs = V.choices.get_mm_plus_mm_configs(device_type)
|
||||
|
||||
if use_triton_template(layout1):
|
||||
for config in mm_configs():
|
||||
# see https://github.com/openai/triton/issues/1298
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from torch.utils._triton import has_triton_tma_device
|
|||
|
||||
from .. import config as inductor_config
|
||||
from ..config import triton as triton_config
|
||||
from ..ir import _IntLike, ChoiceCaller, get_device_type, Layout, StorageBox, TensorBox
|
||||
from ..ir import _IntLike, ChoiceCaller, Layout, StorageBox, TensorBox
|
||||
from ..lowering import add_layout_constraint, constrain_to_fx_strides, register_lowering
|
||||
from ..select_algorithm import (
|
||||
autotune_select_algorithm,
|
||||
|
|
@ -27,8 +27,14 @@ from ..utils import (
|
|||
use_ck_gemm_template,
|
||||
use_triton_template,
|
||||
)
|
||||
from ..virtualized import V
|
||||
from .mm_common import _is_static_problem, mm_args, mm_grid, persistent_mm_grid
|
||||
from .mm_common import (
|
||||
_is_static_problem,
|
||||
mm_args,
|
||||
mm_grid,
|
||||
persistent_mm_grid,
|
||||
scaled_mm_configs,
|
||||
scaled_persistent_mm_configs,
|
||||
)
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
|
@ -509,7 +515,6 @@ def tuned_scaled_mm(
|
|||
m, n, k, layout, mat_a, mat_b = mm_args(
|
||||
mat_a, mat_b, layout=layout, out_dtype=out_dtype
|
||||
)
|
||||
device_type = get_device_type(mat_a)
|
||||
|
||||
check_supported_striding(mat_a, mat_b)
|
||||
|
||||
|
|
@ -535,11 +540,6 @@ def tuned_scaled_mm(
|
|||
|
||||
_, is_nonzero = _is_static_problem(layout)
|
||||
|
||||
scaled_mm_configs = V.choices.get_scaled_mm_configs(device_type)
|
||||
scaled_persistent_mm_configs = V.choices.get_scaled_persistent_mm_configs(
|
||||
device_type
|
||||
)
|
||||
|
||||
if is_nonzero and use_triton_template(layout, enable_float8=True):
|
||||
if use_persistent_tma(k, bias is not None):
|
||||
for config in scaled_persistent_mm_configs(m, n, k):
|
||||
|
|
|
|||
|
|
@ -2,10 +2,8 @@
|
|||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .. import ir
|
||||
from ..select_algorithm import autotune_select_algorithm, TritonTemplate
|
||||
from ..virtualized import V
|
||||
from .mm_common import mm_args, mm_grid, mm_options
|
||||
from .mm_common import mm_args, mm_configs, mm_grid, mm_options
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -77,13 +75,8 @@ uint4x2_mixed_mm_template = TritonTemplate(
|
|||
|
||||
def tuned_uint4x2_mixed_mm(mat1, mat2, mat2_mm_shape, mat2_dtype):
|
||||
m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=None, use_4x2_dim=True)
|
||||
device_type = ir.get_device_type(mat1)
|
||||
|
||||
choices: list[ChoiceCaller] = []
|
||||
b_prologue_cast_type = f"tl.{mat2_dtype}".replace("torch.", "")
|
||||
|
||||
mm_configs = V.choices.get_base_mm_configs(device_type)
|
||||
|
||||
for config in mm_configs(m, n, k):
|
||||
uint4x2_mixed_mm_template.maybe_append_choice(
|
||||
choices,
|
||||
|
|
|
|||
|
|
@ -36,8 +36,7 @@ def set_driver_to_gpu():
|
|||
|
||||
|
||||
def get_backend_options():
|
||||
from triton.runtime import driver
|
||||
|
||||
driver = triton.runtime.driver
|
||||
target = driver.active.get_current_target()
|
||||
backend = triton.compiler.compiler.make_backend(target)
|
||||
options = backend.parse_options(dict())
|
||||
|
|
|
|||
|
|
@ -1,613 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
from collections import namedtuple
|
||||
from functools import partial
|
||||
from threading import Lock
|
||||
from typing import Any, Callable, Generator, List, Sequence, Tuple, Type, TYPE_CHECKING
|
||||
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
from . import config
|
||||
from .utils import get_backend_num_stages
|
||||
from .virtualized import V
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from triton import Config as TritonConfig
|
||||
|
||||
|
||||
class BaseConfigSingleton(type):
|
||||
"""
|
||||
Thread-safe implementation of single to be used in the config heuristic subclasses
|
||||
to ensure heavy __init__ calls are not repeatedly run
|
||||
"""
|
||||
|
||||
_instances: dict[Type[Any], Any] = {}
|
||||
_lock: Lock = Lock()
|
||||
|
||||
def __call__(
|
||||
cls: BaseConfigSingleton, *args: Any, **kwargs: Any
|
||||
) -> BaseConfigHeuristic:
|
||||
with cls._lock:
|
||||
if cls not in cls._instances:
|
||||
instance = super().__call__()
|
||||
cls._instances[cls] = instance
|
||||
return cls._instances[cls]
|
||||
|
||||
|
||||
Config = namedtuple(
|
||||
"Config", ["block_m", "block_n", "block_k", "num_stages", "num_warps"]
|
||||
)
|
||||
|
||||
|
||||
class BaseConfigHeuristic(metaclass=BaseConfigSingleton):
|
||||
"""
|
||||
Base class for mm_configs, device specific triton kernels config inherit from here
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# List of dictionaries to store the kernel configs. Configs that evaluate to true
|
||||
# will be utilised on the target platform. The configs are as follows:
|
||||
# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
|
||||
self.mm_configs = [
|
||||
Config(32, 32, 16, 1, 2),
|
||||
Config(32, 32, 128, 2, 4),
|
||||
Config(32, 64, 32, 5, 8),
|
||||
Config(64, 32, 32, 5, 8),
|
||||
Config(64, 32, 128, 5, 4),
|
||||
Config(64, 64, 16, 2, 4),
|
||||
Config(64, 64, 32, 2, 4),
|
||||
Config(64, 64, 64, 3, 8),
|
||||
Config(64, 64, 128, 5, 4),
|
||||
Config(64, 128, 32, 3, 4),
|
||||
Config(64, 128, 32, 4, 8),
|
||||
Config(64, 128, 64, 3, 4),
|
||||
Config(64, 128, 128, 4, 4),
|
||||
Config(128, 64, 32, 3, 4),
|
||||
Config(128, 64, 32, 4, 8),
|
||||
Config(128, 128, 32, 2, 8),
|
||||
Config(128, 128, 32, 3, 4),
|
||||
Config(128, 128, 64, 3, 4),
|
||||
Config(128, 128, 64, 5, 8),
|
||||
]
|
||||
|
||||
# Exhaustive search for mm configs
|
||||
self.exhaustive_configs = [
|
||||
Config(BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
|
||||
for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product(
|
||||
[16, 32, 64, 128, 256], repeat=3
|
||||
)
|
||||
for num_stages in [1, 2, 3, 4, 5]
|
||||
for num_warps in [2, 4, 8]
|
||||
]
|
||||
|
||||
# these are only used in tuned_mm when AutoHeuristic is enabled
|
||||
# the idea is that when AutoHeuristic collects data to learn a heuristic, more configs are autotuned
|
||||
# when the learned heuristic is used, the learned heuristic reduces the number of configs down to 10
|
||||
# which saves compilation time (since less configs are autotuned) and potentially increase performance
|
||||
# because the learned heuristic might predict a config that is not part mm_configs
|
||||
self.extra_mm_configs = [
|
||||
Config(16, 32, 16, 3, 2),
|
||||
Config(16, 32, 32, 4, 2),
|
||||
Config(16, 32, 32, 5, 2),
|
||||
Config(64, 64, 128, 3, 4),
|
||||
Config(128, 64, 32, 2, 2),
|
||||
Config(128, 64, 64, 3, 8),
|
||||
Config(128, 64, 128, 4, 8),
|
||||
Config(128, 128, 32, 4, 4),
|
||||
Config(128, 128, 64, 3, 8),
|
||||
Config(128, 128, 64, 5, 4),
|
||||
]
|
||||
|
||||
self.int8_mm_configs = [
|
||||
Config(64, 64, 32, 2, 4),
|
||||
Config(64, 128, 32, 3, 4),
|
||||
Config(128, 64, 32, 3, 4),
|
||||
Config(64, 128, 32, 4, 8),
|
||||
Config(128, 64, 32, 4, 8),
|
||||
Config(64, 32, 32, 5, 8),
|
||||
Config(32, 64, 32, 5, 8),
|
||||
Config(128, 128, 32, 2, 8),
|
||||
Config(64, 64, 64, 3, 8),
|
||||
Config(128, 256, 128, 3, 8),
|
||||
Config(256, 128, 128, 3, 8),
|
||||
]
|
||||
|
||||
self.mixed_mm_configs = [
|
||||
Config(16, 128, 256, 3, 4),
|
||||
Config(16, 128, 256, 5, 8),
|
||||
]
|
||||
|
||||
self.persistent_mm_configs = [
|
||||
Config(128, 256, 64, 3, 8),
|
||||
Config(128, 128, 64, 3, 8),
|
||||
Config(128, 128, 128, 3, 8),
|
||||
Config(128, 128, 128, 3, 4),
|
||||
Config(128, 128, 64, 4, 8),
|
||||
]
|
||||
|
||||
self.scaled_mm_configs = [
|
||||
Config(128, 256, 32, 3, 8),
|
||||
Config(256, 128, 32, 3, 8),
|
||||
Config(256, 64, 32, 4, 4),
|
||||
Config(64, 256, 32, 4, 4),
|
||||
Config(128, 128, 32, 4, 4),
|
||||
Config(128, 64, 32, 4, 4),
|
||||
Config(64, 128, 32, 4, 4),
|
||||
Config(128, 32, 32, 4, 4),
|
||||
Config(64, 32, 32, 5, 2),
|
||||
Config(256, 128, 128, 3, 8),
|
||||
Config(256, 64, 128, 4, 4),
|
||||
Config(64, 256, 128, 4, 4),
|
||||
Config(128, 128, 128, 4, 4),
|
||||
Config(128, 64, 64, 4, 4),
|
||||
Config(64, 128, 64, 4, 4),
|
||||
Config(128, 32, 64, 4, 4),
|
||||
Config(64, 32, 64, 5, 2),
|
||||
Config(16, 32, 32, 2, 2),
|
||||
Config(16, 64, 32, 2, 2),
|
||||
Config(16, 128, 32, 2, 4),
|
||||
Config(16, 256, 32, 2, 4),
|
||||
Config(16, 32, 64, 2, 2),
|
||||
Config(16, 64, 64, 2, 2),
|
||||
Config(16, 128, 64, 2, 4),
|
||||
Config(16, 256, 64, 2, 4),
|
||||
Config(32, 32, 32, 2, 2),
|
||||
Config(32, 64, 32, 2, 2),
|
||||
Config(32, 128, 32, 2, 4),
|
||||
Config(32, 256, 32, 2, 4),
|
||||
Config(32, 32, 64, 2, 2),
|
||||
Config(32, 64, 64, 2, 2),
|
||||
Config(32, 128, 64, 2, 4),
|
||||
Config(32, 256, 64, 2, 4),
|
||||
Config(16, 32, 32, 3, 2),
|
||||
Config(16, 64, 32, 3, 2),
|
||||
Config(16, 128, 32, 3, 4),
|
||||
Config(16, 256, 32, 3, 4),
|
||||
Config(16, 32, 64, 3, 2),
|
||||
Config(16, 64, 64, 3, 2),
|
||||
Config(16, 128, 64, 3, 4),
|
||||
Config(16, 256, 64, 3, 4),
|
||||
Config(32, 32, 32, 3, 2),
|
||||
Config(32, 64, 32, 3, 2),
|
||||
Config(32, 128, 32, 3, 4),
|
||||
Config(32, 256, 32, 3, 4),
|
||||
Config(32, 32, 64, 3, 2),
|
||||
Config(32, 64, 64, 3, 2),
|
||||
Config(32, 128, 64, 3, 4),
|
||||
Config(32, 256, 64, 3, 4),
|
||||
Config(16, 32, 32, 4, 2),
|
||||
Config(16, 64, 32, 4, 2),
|
||||
Config(16, 128, 32, 4, 4),
|
||||
Config(16, 256, 32, 4, 4),
|
||||
Config(16, 32, 64, 4, 2),
|
||||
Config(16, 64, 64, 4, 2),
|
||||
Config(16, 128, 64, 4, 4),
|
||||
Config(16, 256, 64, 4, 4),
|
||||
Config(32, 32, 32, 4, 2),
|
||||
Config(32, 64, 32, 4, 2),
|
||||
Config(32, 128, 32, 4, 4),
|
||||
Config(32, 256, 32, 4, 4),
|
||||
Config(32, 32, 64, 4, 2),
|
||||
Config(32, 64, 64, 4, 2),
|
||||
Config(32, 128, 64, 4, 4),
|
||||
Config(32, 256, 64, 4, 4),
|
||||
Config(16, 32, 32, 5, 2),
|
||||
Config(16, 64, 32, 5, 2),
|
||||
Config(16, 128, 32, 5, 4),
|
||||
Config(16, 256, 32, 5, 4),
|
||||
Config(16, 32, 64, 5, 2),
|
||||
Config(16, 64, 64, 5, 2),
|
||||
Config(16, 128, 64, 5, 4),
|
||||
Config(16, 256, 64, 5, 4),
|
||||
Config(32, 32, 32, 5, 2),
|
||||
Config(32, 64, 32, 5, 2),
|
||||
Config(32, 128, 32, 5, 4),
|
||||
Config(32, 256, 32, 5, 4),
|
||||
Config(32, 32, 64, 5, 2),
|
||||
Config(32, 64, 64, 5, 2),
|
||||
Config(32, 128, 64, 5, 4),
|
||||
Config(32, 256, 64, 5, 4),
|
||||
Config(16, 32, 32, 6, 2),
|
||||
Config(16, 64, 32, 6, 2),
|
||||
Config(16, 128, 32, 6, 4),
|
||||
Config(16, 256, 32, 6, 4),
|
||||
Config(16, 32, 64, 6, 2),
|
||||
Config(16, 64, 64, 6, 2),
|
||||
Config(16, 128, 64, 6, 4),
|
||||
Config(16, 256, 64, 6, 4),
|
||||
Config(32, 32, 32, 6, 2),
|
||||
Config(32, 64, 32, 6, 2),
|
||||
Config(32, 128, 32, 6, 4),
|
||||
Config(32, 256, 32, 6, 4),
|
||||
Config(32, 32, 64, 6, 2),
|
||||
Config(32, 64, 64, 6, 2),
|
||||
Config(32, 128, 64, 6, 4),
|
||||
Config(32, 256, 64, 6, 4),
|
||||
]
|
||||
|
||||
self.scaled_persistent_mm_configs = [
|
||||
Config(128, 128, 64, 3, 8),
|
||||
Config(128, 128, 128, 3, 8),
|
||||
Config(128, 128, 128, 4, 8),
|
||||
Config(128, 128, 128, 4, 4),
|
||||
Config(128, 128, 128, 3, 4),
|
||||
Config(128, 128, 128, 5, 4),
|
||||
Config(128, 128, 128, 5, 8),
|
||||
Config(128, 128, 128, 6, 8),
|
||||
Config(128, 128, 64, 4, 8),
|
||||
]
|
||||
|
||||
# TODO: Unify with other gemm patterns, mm_plus_mm currently follows
|
||||
# slightly different pattern than rest
|
||||
self.mm_plus_mm_configs = [
|
||||
Config(64, 64, 32, 2, 4),
|
||||
Config(64, 64, 32, 3, 8),
|
||||
Config(64, 64, 32, 4, 16),
|
||||
Config(64, 32, 32, 4, 8),
|
||||
Config(32, 64, 32, 4, 8),
|
||||
Config(128, 128, 32, 1, 8),
|
||||
Config(64, 64, 64, 1, 8),
|
||||
Config(32, 32, 128, 1, 8),
|
||||
Config(64, 64, 16, 2, 4),
|
||||
Config(32, 32, 16, 1, 2),
|
||||
]
|
||||
|
||||
self.conv_configs = [
|
||||
Config(64, 256, 16, 2, 4),
|
||||
Config(256, 64, 16, 2, 4),
|
||||
Config(1024, 16, 16, 1, 8),
|
||||
Config(128, 128, 32, 2, 8),
|
||||
Config(64, 64, 32, 2, 4),
|
||||
Config(64, 256, 32, 2, 8),
|
||||
Config(256, 64, 32, 2, 8),
|
||||
]
|
||||
|
||||
def _finalize_mm_configs(
|
||||
self,
|
||||
configs: List[Config],
|
||||
) -> Generator[TritonConfig, None, None]:
|
||||
"""
|
||||
Finalizes configs after scaling, applying additional constraints.
|
||||
"""
|
||||
used = OrderedSet[Config]()
|
||||
|
||||
max_mm_configs = config.test_configs.max_mm_configs
|
||||
|
||||
for block_m, block_n, block_k, num_stages, num_warps in configs:
|
||||
# Each warp computes a 16x16 tile = 256 elements
|
||||
num_warps = min(num_warps, block_m * block_n // 256)
|
||||
|
||||
if (block_m, block_n, block_k, num_stages, num_warps, 0) not in used and (
|
||||
max_mm_configs is None or len(used) < max_mm_configs
|
||||
):
|
||||
used.add(Config(block_m, block_n, block_k, num_stages, num_warps))
|
||||
yield self.triton_config(
|
||||
BLOCK_M=block_m,
|
||||
BLOCK_N=block_n,
|
||||
BLOCK_K=block_k,
|
||||
num_stages=num_stages,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
|
||||
def _scale_mm_configs(
|
||||
self,
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
configs: Sequence[Config],
|
||||
scale: float,
|
||||
has_int8_tensor: bool,
|
||||
exclude: Callable[[int, int, int], bool],
|
||||
) -> List[Config]:
|
||||
"""
|
||||
Scales and filters matrix multiplication configs based on input size.
|
||||
"""
|
||||
from .runtime.runtime_utils import next_power_of_2
|
||||
|
||||
min_block_size = 16
|
||||
min_block_size_k = 32 if has_int8_tensor else 16
|
||||
|
||||
m = max(
|
||||
next_power_of_2(
|
||||
V.graph.sizevars.size_hint(
|
||||
m, fallback=config.unbacked_symint_fallback # type: ignore[arg-type]
|
||||
)
|
||||
),
|
||||
min_block_size,
|
||||
)
|
||||
n = max(
|
||||
next_power_of_2(
|
||||
V.graph.sizevars.size_hint(
|
||||
n, fallback=config.unbacked_symint_fallback # type: ignore[arg-type]
|
||||
)
|
||||
),
|
||||
min_block_size,
|
||||
)
|
||||
k = max(
|
||||
next_power_of_2(
|
||||
V.graph.sizevars.size_hint(
|
||||
k, fallback=config.unbacked_symint_fallback # type: ignore[arg-type]
|
||||
)
|
||||
),
|
||||
min_block_size_k,
|
||||
)
|
||||
|
||||
scaled_configs = []
|
||||
for c in configs:
|
||||
scaled_config = c._replace(
|
||||
block_m=max(min(int(c.block_m * scale), m), min_block_size),
|
||||
block_n=max(min(int(c.block_n * scale), n), min_block_size),
|
||||
block_k=max(min(int(c.block_k * scale), k), min_block_size_k),
|
||||
)
|
||||
|
||||
if not exclude(
|
||||
scaled_config.block_m, scaled_config.block_n, scaled_config.block_k
|
||||
):
|
||||
scaled_configs.append(scaled_config)
|
||||
|
||||
return scaled_configs
|
||||
|
||||
def preprocess_mm_configs(
|
||||
self,
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
configs: Sequence[Config],
|
||||
has_int8_tensor: bool = False,
|
||||
scale: int = 1,
|
||||
exclude: Callable[[int, int, int], bool] = lambda m, n, k: False,
|
||||
) -> Generator[TritonConfig, None, None]:
|
||||
scaled_configs = self._scale_mm_configs(
|
||||
m, n, k, configs, scale, has_int8_tensor, exclude
|
||||
)
|
||||
return self._finalize_mm_configs(scaled_configs)
|
||||
|
||||
def triton_config(
|
||||
self, num_stages: int, num_warps: int, **kwargs: Any
|
||||
) -> TritonConfig:
|
||||
from triton import Config as TritonConfig # type: ignore[attr-defined]
|
||||
|
||||
return TritonConfig(kwargs, num_stages=num_stages, num_warps=num_warps)
|
||||
|
||||
def get_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
||||
return partial(self.preprocess_mm_configs, configs=self.mm_configs)
|
||||
|
||||
def get_exhaustive_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
||||
return partial(self.preprocess_mm_configs, configs=self.exhaustive_configs)
|
||||
|
||||
def get_extra_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
||||
return partial(self.preprocess_mm_configs, configs=self.extra_mm_configs)
|
||||
|
||||
def get_int8_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
||||
return partial(self.preprocess_mm_configs, configs=self.int8_mm_configs)
|
||||
|
||||
def get_mixed_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
||||
mm_configs = (
|
||||
self.mm_configs + self.mixed_mm_configs
|
||||
if config.max_autotune_gemm_search_space == "EXHAUSTIVE"
|
||||
else self.mm_configs
|
||||
)
|
||||
return partial(self.preprocess_mm_configs, configs=mm_configs)
|
||||
|
||||
def get_persistent_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
||||
return partial(self.preprocess_mm_configs, configs=self.persistent_mm_configs)
|
||||
|
||||
def get_scaled_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
||||
return partial(self.preprocess_mm_configs, configs=self.scaled_mm_configs)
|
||||
|
||||
def get_scaled_persistent_mm_configs(
|
||||
self,
|
||||
) -> partial[Generator[TritonConfig, None, None]]:
|
||||
return partial(
|
||||
self.preprocess_mm_configs, configs=self.scaled_persistent_mm_configs
|
||||
)
|
||||
|
||||
def get_mm_plus_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
||||
return partial(self._finalize_mm_configs, configs=self.mm_plus_mm_configs)
|
||||
|
||||
def get_conv_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
||||
return partial(self.preprocess_mm_configs, configs=self.conv_configs)
|
||||
|
||||
def generate_mixed_mm_config(self, m: int, n: int, k: int) -> TritonConfig:
|
||||
if m <= 16 and n >= 4096 and k >= 4096:
|
||||
return self.triton_config(
|
||||
BLOCK_M=16,
|
||||
BLOCK_N=64,
|
||||
BLOCK_K=128,
|
||||
num_stages=5,
|
||||
num_warps=4,
|
||||
)
|
||||
elif m > 16 and m <= 32 and n >= 4096 and k >= 4096:
|
||||
return self.triton_config(
|
||||
BLOCK_M=32,
|
||||
BLOCK_N=32,
|
||||
BLOCK_K=128,
|
||||
num_stages=5,
|
||||
num_warps=4,
|
||||
)
|
||||
elif m > 32 and m <= 64 and n >= 4096 and k >= 4096:
|
||||
return self.triton_config(
|
||||
BLOCK_M=64,
|
||||
BLOCK_N=32,
|
||||
BLOCK_K=128,
|
||||
num_stages=5,
|
||||
num_warps=4,
|
||||
)
|
||||
|
||||
|
||||
class CPUConfigHeuristic(BaseConfigHeuristic):
|
||||
pass
|
||||
|
||||
|
||||
class CUDAConfigHeuristic(BaseConfigHeuristic):
|
||||
pass
|
||||
|
||||
|
||||
class ROCmConfigHeuristic(BaseConfigHeuristic):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.default_num_stages = get_backend_num_stages()
|
||||
|
||||
# Exhaustive search for mm configs
|
||||
self.exhaustive_configs = [
|
||||
Config(BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
|
||||
for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product(
|
||||
[16, 32, 64, 128, 256], repeat=3
|
||||
)
|
||||
for num_stages in [1, self.default_num_stages]
|
||||
for num_warps in [4, 8]
|
||||
]
|
||||
|
||||
def _filter_configs(
|
||||
self, configs: List[Config], new_num_stages: int
|
||||
) -> List[Config]:
|
||||
filtered_configs = [
|
||||
c._replace(num_stages=self.default_num_stages) for c in configs
|
||||
]
|
||||
return filtered_configs
|
||||
|
||||
def _finalize_mm_configs(
|
||||
self,
|
||||
configs: List[Config],
|
||||
) -> Generator[TritonConfig, None, None]:
|
||||
used = OrderedSet[Tuple[Config, int]]()
|
||||
|
||||
max_mm_configs = config.test_configs.max_mm_configs
|
||||
for block_m, block_n, block_k, num_stages, num_warps in configs:
|
||||
# each warp computes 16x16 tile = 256
|
||||
num_warps = min(num_warps, block_m * block_n // 256)
|
||||
|
||||
for matrix_instr_nonkdim in [0, 16]:
|
||||
if matrix_instr_nonkdim != 0 and (
|
||||
block_m % matrix_instr_nonkdim != 0
|
||||
or block_n % matrix_instr_nonkdim != 0
|
||||
):
|
||||
# block_m and block_n must be a multiple of matrix_instr_nonkdim
|
||||
continue
|
||||
if (
|
||||
block_m,
|
||||
block_n,
|
||||
block_k,
|
||||
num_stages,
|
||||
num_warps,
|
||||
matrix_instr_nonkdim,
|
||||
) not in used and (
|
||||
max_mm_configs is None or len(used) < max_mm_configs
|
||||
):
|
||||
used.add(
|
||||
(
|
||||
Config(
|
||||
block_m,
|
||||
block_n,
|
||||
block_k,
|
||||
num_stages,
|
||||
num_warps,
|
||||
),
|
||||
matrix_instr_nonkdim,
|
||||
)
|
||||
)
|
||||
|
||||
yield self.triton_config(
|
||||
BLOCK_M=block_m,
|
||||
BLOCK_N=block_n,
|
||||
BLOCK_K=block_k,
|
||||
num_stages=num_stages,
|
||||
num_warps=num_warps,
|
||||
matrix_instr_nonkdim=matrix_instr_nonkdim,
|
||||
)
|
||||
|
||||
def get_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
||||
filtered_configs = self._filter_configs(
|
||||
self.mm_configs, self.default_num_stages
|
||||
)
|
||||
return partial(self.preprocess_mm_configs, configs=filtered_configs)
|
||||
|
||||
def get_exhaustive_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
||||
filtered_configs = self._filter_configs(
|
||||
self.exhaustive_configs, self.default_num_stages
|
||||
)
|
||||
return partial(self.preprocess_mm_configs, configs=filtered_configs)
|
||||
|
||||
def get_extra_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
||||
filtered_configs = self._filter_configs(
|
||||
self.extra_mm_configs, self.default_num_stages
|
||||
)
|
||||
return partial(self.preprocess_mm_configs, configs=filtered_configs)
|
||||
|
||||
def get_int8_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
||||
filtered_configs = self._filter_configs(
|
||||
self.int8_mm_configs, self.default_num_stages
|
||||
)
|
||||
return partial(self.preprocess_mm_configs, configs=filtered_configs)
|
||||
|
||||
def get_mixed_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
||||
mm_configs = (
|
||||
self.mm_configs + self.mixed_mm_configs
|
||||
if config.max_autotune_gemm_search_space == "EXHAUSTIVE"
|
||||
else self.mm_configs
|
||||
)
|
||||
filtered_configs = self._filter_configs(mm_configs, self.default_num_stages)
|
||||
return partial(self.preprocess_mm_configs, configs=filtered_configs)
|
||||
|
||||
def get_persistent_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
||||
filtered_configs = self._filter_configs(
|
||||
self.persistent_mm_configs, self.default_num_stages
|
||||
)
|
||||
return partial(self.preprocess_mm_configs, configs=filtered_configs)
|
||||
|
||||
def get_scaled_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
||||
filtered_configs = self._filter_configs(
|
||||
self.scaled_mm_configs, self.default_num_stages
|
||||
)
|
||||
return partial(self.preprocess_mm_configs, configs=filtered_configs)
|
||||
|
||||
def get_scaled_persistent_mm_configs(
|
||||
self,
|
||||
) -> partial[Generator[TritonConfig, None, None]]:
|
||||
filtered_configs = self._filter_configs(
|
||||
self.scaled_persistent_mm_configs, self.default_num_stages
|
||||
)
|
||||
return partial(self.preprocess_mm_configs, configs=filtered_configs)
|
||||
|
||||
def get_mm_plus_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
||||
filtered_configs = self._filter_configs(self.mm_plus_mm_configs, 1)
|
||||
return partial(self._finalize_mm_configs, configs=filtered_configs)
|
||||
|
||||
def get_conv_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
||||
filtered_configs = self._filter_configs(
|
||||
self.conv_configs, self.default_num_stages
|
||||
)
|
||||
return partial(self.preprocess_mm_configs, configs=filtered_configs)
|
||||
|
||||
def generate_mixed_mm_config(self, m: int, n: int, k: int) -> TritonConfig:
|
||||
if m <= 16 and n >= 4096 and k >= 4096:
|
||||
return self.triton_config(
|
||||
BLOCK_M=16,
|
||||
BLOCK_N=64,
|
||||
BLOCK_K=128,
|
||||
num_stages=self.default_num_stages,
|
||||
num_warps=4,
|
||||
)
|
||||
elif m > 16 and m <= 32 and n >= 4096 and k >= 4096:
|
||||
return self.triton_config(
|
||||
BLOCK_M=32,
|
||||
BLOCK_N=32,
|
||||
BLOCK_K=128,
|
||||
num_stages=self.default_num_stages,
|
||||
num_warps=4,
|
||||
)
|
||||
elif m > 32 and m <= 64 and n >= 4096 and k >= 4096:
|
||||
return self.triton_config(
|
||||
BLOCK_M=64,
|
||||
BLOCK_N=32,
|
||||
BLOCK_K=128,
|
||||
num_stages=self.default_num_stages,
|
||||
num_warps=4,
|
||||
)
|
||||
|
||||
|
||||
class XPUConfigHeuristic(BaseConfigHeuristic):
|
||||
pass
|
||||
Loading…
Reference in New Issue
Block a user