Reland "Introduce new template heuristic for triton autotune configs" (#147452)

This change was reverted in https://github.com/pytorch/pytorch/pull/147388 for regressing an internal workload. This reland removes the additional ir.get_device_type calls in mm_scaled and unpack_mixed_mm.py, which could have been contributing to the extra compile time.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147452
Approved by: https://github.com/jansel
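The core idea of the commit is to move the hard-coded per-op Triton config lists out of the lowering files and behind per-device "config heuristic" classes selected by device type. The following is a minimal, self-contained sketch of that dispatch pattern; the class and method names mirror the diff below, but this snippet is illustrative only and is not the actual torch._inductor implementation (for example, it pins ROCm num_stages to 1 instead of querying the backend).

# Illustrative sketch of the per-device heuristic dispatch introduced by this PR.
from dataclasses import dataclass


@dataclass(frozen=True)
class GemmConfig:
    block_m: int
    block_n: int
    block_k: int
    num_stages: int
    num_warps: int


class BaseConfigHeuristic:
    # Default (device-agnostic) GEMM tile shapes.
    mm_configs = [GemmConfig(64, 64, 32, 2, 4), GemmConfig(128, 128, 32, 3, 4)]

    def get_mm_configs(self):
        return list(self.mm_configs)


class CUDAConfigHeuristic(BaseConfigHeuristic):
    pass  # CUDA keeps the base configs in this sketch.


class ROCmConfigHeuristic(BaseConfigHeuristic):
    def get_mm_configs(self):
        # ROCm rewrites num_stages (the removed conv.py comment notes pipelining
        # provides no benefit there); the real code queries get_backend_num_stages().
        return [GemmConfig(c.block_m, c.block_n, c.block_k, 1, c.num_warps)
                for c in self.mm_configs]


class CPUConfigHeuristic(BaseConfigHeuristic):
    pass


def get_config_heuristics(device_type: str = "cuda", is_hip: bool = False):
    # Same shape as InductorChoices.get_config_heuristics in the diff below.
    if device_type == "cuda":
        return ROCmConfigHeuristic() if is_hip else CUDAConfigHeuristic()
    if device_type == "cpu":
        return CPUConfigHeuristic()
    return BaseConfigHeuristic()


print(get_config_heuristics("cuda", is_hip=True).get_mm_configs()[0])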
This commit is contained in:
parent 7336b76bcc
commit 32299e5f9a
@@ -1,20 +1,33 @@
 from __future__ import annotations

 import typing
-from typing import Any, TYPE_CHECKING
+from typing import Any, Optional, TYPE_CHECKING

 import sympy

+import torch
+
 from . import config
 from .codecache import write_text
 from .metrics import get_metric_table, is_metric_table_enabled
 from .runtime.hints import DeviceProperties, ReductionHint
 from .scheduler import BaseSchedulerNode, Scheduler, WhyNoFuse
+from .template_heuristics import (
+    BaseConfigHeuristic,
+    CPUConfigHeuristic,
+    CUDAConfigHeuristic,
+    ROCmConfigHeuristic,
+    XPUConfigHeuristic,
+)
 from .virtualized import V


 if TYPE_CHECKING:
-    import torch
+    from collections.abc import Generator
+    from functools import partial
+
+    from triton import Config as TritonConfig
+
     from torch.utils._ordered_set import OrderedSet

     from .codegen.simd_kernel_features import SIMDKernelFeatures

@@ -40,6 +53,80 @@ class InductorChoices:
         torch._inductor.virtualized.V.set_choices_handler(MyHeuristics())
     """

+    def get_config_heuristics(
+        self, device_type: Optional[str] = "cuda"
+    ) -> BaseConfigHeuristic:
+        if device_type == "cuda":
+            if torch.version.hip is None:
+                return CUDAConfigHeuristic()
+            else:
+                return ROCmConfigHeuristic()
+        elif device_type == "xpu":
+            return XPUConfigHeuristic()
+        elif device_type == "cpu":
+            return CPUConfigHeuristic()
+        else:
+            return BaseConfigHeuristic()
+
+    # GEMM configs
+    def get_base_mm_configs(
+        self, device_type: Optional[str] = "cuda"
+    ) -> partial[Generator[TritonConfig, None, None]]:
+        mm_heuristics = self.get_config_heuristics(device_type)
+        if config.max_autotune_gemm_search_space != "EXHAUSTIVE":
+            return mm_heuristics.get_mm_configs()
+        else:
+            return mm_heuristics.get_exhaustive_mm_configs()
+
+    def get_extra_mm_configs(
+        self, device_type: Optional[str] = "cuda"
+    ) -> partial[Generator[TritonConfig, None, None]]:
+        mm_heuristics = self.get_config_heuristics(device_type)
+        return mm_heuristics.get_extra_mm_configs()
+
+    def get_int8_mm_configs(
+        self, device_type: Optional[str] = "cuda"
+    ) -> partial[Generator[TritonConfig, None, None]]:
+        mm_heuristics = self.get_config_heuristics(device_type)
+        return mm_heuristics.get_int8_mm_configs()
+
+    def get_mixed_mm_configs(
+        self, device_type: Optional[str] = "cuda"
+    ) -> partial[Generator[TritonConfig, None, None]]:
+        mm_heuristics = self.get_config_heuristics(device_type)
+        return mm_heuristics.get_mixed_mm_configs()
+
+    def get_persistent_mm_configs(
+        self, device_type: Optional[str] = "cuda"
+    ) -> partial[Generator[TritonConfig, None, None]]:
+        mm_heuristics = self.get_config_heuristics(device_type)
+        return mm_heuristics.get_persistent_mm_configs()
+
+    def get_scaled_mm_configs(
+        self, device_type: Optional[str] = "cuda"
+    ) -> partial[Generator[TritonConfig, None, None]]:
+        mm_heuristics = self.get_config_heuristics(device_type)
+        return mm_heuristics.get_scaled_mm_configs()
+
+    def get_scaled_persistent_mm_configs(
+        self, device_type: Optional[str] = "cuda"
+    ) -> partial[Generator[TritonConfig, None, None]]:
+        mm_heuristics = self.get_config_heuristics(device_type)
+        return mm_heuristics.get_scaled_persistent_mm_configs()
+
+    def get_mm_plus_mm_configs(
+        self, device_type: Optional[str] = "cuda"
+    ) -> partial[Generator[TritonConfig, None, None]]:
+        mm_heuristics = self.get_config_heuristics(device_type)
+        return mm_heuristics.get_mm_plus_mm_configs()
+
+    # Conv configs
+    def get_conv_configs(
+        self, device_type: Optional[str] = "cuda"
+    ) -> partial[Generator[TritonConfig, None, None]]:
+        conv_heuristics = self.get_config_heuristics(device_type)
+        return conv_heuristics.get_conv_configs()
+
     def triton_kernel_kwargs(
         self,
         kernel_cls: type[TritonKernel],
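Each getter above returns a functools.partial over a config generator, so the lowering files below can call it with (m, n, k) plus optional CPU-only keyword arguments. The sketch below shows that calling convention end to end; it is a stand-in, not the real inductor code, and FakeChoicesHandler and _filtered_configs are hypothetical names that loosely follow mm_common.filtered_configs before this refactor.

# Sketch of how a lowering consumes the new per-device config getters.
import functools
from typing import Iterator


def _filtered_configs(m: int, n: int, k: int, *, configs, scale: float = 1.0,
                      exclude=lambda m, n, k: False) -> Iterator[dict]:
    # Shrink tile sizes toward the problem size and drop excluded shapes.
    for block_m, block_n, block_k, num_stages, num_warps in configs:
        block_m = max(min(int(block_m * scale), m), 16)
        block_n = max(min(int(block_n * scale), n), 16)
        block_k = max(min(int(block_k * scale), k), 16)
        if exclude(block_m, block_n, block_k):
            continue
        yield dict(BLOCK_M=block_m, BLOCK_N=block_n, BLOCK_K=block_k,
                   num_stages=num_stages, num_warps=num_warps)


class FakeChoicesHandler:
    """Stands in for V.choices in this sketch."""

    _mm_configs = [(64, 64, 32, 2, 4), (128, 128, 32, 3, 4)]

    def get_base_mm_configs(self, device_type: str = "cuda"):
        # Returns a partial, matching the partial[Generator[...]] signature above.
        return functools.partial(_filtered_configs, configs=self._mm_configs)


def mm_config_kwargs(device_type: str, exclude_condition):
    # CPU shrinks tiles and filters large blocks to keep compile times down.
    if device_type == "cpu":
        return {"scale": 0.5, "exclude": exclude_condition}
    return {}


device_type = "cpu"
mm_configs = FakeChoicesHandler().get_base_mm_configs(device_type)
for cfg in mm_configs(512, 512, 512,
                      **mm_config_kwargs(device_type, lambda m, n, k: m * n > 2**13)):
    print(cfg)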
@@ -24,7 +24,7 @@ from .mm_common import (
     _is_static_problem,
     addmm_epilogue,
     mm_args,
-    mm_configs,
+    mm_config_kwargs,
     mm_options,
     should_fallback_to_aten,
 )
@@ -46,12 +46,6 @@ def _is_large_block_for_cpu(m, n, k):
     return m * n > 2**12


-def bmm_configs(m, n, k, *, device_type):
-    if device_type == "cpu":
-        return mm_configs(m, n, k, scale=0.5, exclude=_is_large_block_for_cpu)
-    return mm_configs(m, n, k)
-
-
 bmm_template = TritonTemplate(
     name="bmm",
     grid=bmm_grid,
@@ -184,8 +178,14 @@ def tuned_bmm(mat1, mat2, *, layout=None):
     # options to tune from
     choices = [aten_bmm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else []

+    device_type = ir.get_device_type(mat1)
+    bmm_configs = V.choices.get_base_mm_configs(device_type)
+
     if use_triton_template(layout):
-        for config in bmm_configs(m, n, k, device_type=ir.get_device_type(mat1)):
+        for config in bmm_configs(
+            m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
+        ):
             bmm_template.maybe_append_choice(
                 choices,
                 input_nodes=(mat1, mat2),
@@ -239,8 +239,14 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
         if use_aten_gemm_kernels()
         else []
     )

+    device_type = ir.get_device_type(mat1)
+    bmm_configs = V.choices.get_base_mm_configs(device_type)
+
     if use_triton_template(layout):
-        for config in bmm_configs(m, n, k, device_type=ir.get_device_type(mat1)):
+        for config in bmm_configs(
+            m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
+        ):
             bmm_template.maybe_append_choice(
                 choices,
                 input_nodes=(inp, mat1, mat2),
@@ -2,7 +2,7 @@
 from __future__ import annotations

 import logging
-from typing import cast, Optional, TYPE_CHECKING, TypedDict
+from typing import Optional, TYPE_CHECKING, TypedDict

 import torch
 from torch._inductor.codegen.rocm.ck_conv_template import CKGroupedConvFwdTemplate
@@ -29,7 +29,7 @@ from ..utils import (
     use_triton_template,
 )
 from ..virtualized import V
-from .mm_common import build_rocm_gemm_configs, filtered_configs
+from .mm_common import mm_config_kwargs


 if TYPE_CHECKING:
@@ -61,31 +61,6 @@ def conv3d_grid(n, c, d, h, w, meta, *, cdiv):
     )


-# List of dictionaries to store the kernel configs. Configs that evaluate to true
-# will be utilised on the target platform
-kernel_configs = [
-    # "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps"
-    {"config": (64, 256, 16, 2, 4), "cond": True},
-    {"config": (256, 64, 16, 2, 4), "cond": True},
-    {"config": (1024, 16, 16, 1, 8), "cond": True},
-    {"config": (128, 128, 32, 2, 8), "cond": True},
-    {"config": (64, 64, 32, 2, 4), "cond": True},
-    {"config": (64, 256, 32, 2, 8), "cond": True},
-    {"config": (256, 64, 32, 2, 8), "cond": True},
-]
-
-# Create filtered list of configs based on conv
-platform_configs = tuple(
-    cast(tuple[int, int, int, int, int], config["config"])
-    for config in kernel_configs
-    if config["cond"]
-)
-
-# On ROCm convert num_stages to 1 as pipelining provides no benefit
-if torch.version.hip and torch.cuda.is_available():
-    platform_configs = build_rocm_gemm_configs(platform_configs)
-
-
 def _is_large_block_for_cpu(m, n, k):
     # Thresholds are experimentally determined to reduce Triton CPU compile times
     if m > 256 or n > 256 or k > 256:
@@ -93,19 +68,6 @@ def _is_large_block_for_cpu(m, n, k):
     return m * n * k > 2**17


-def conv_configs(m, n, k, *, device_type, **kwargs):
-    if device_type == "cpu":
-        return filtered_configs(
-            m,
-            n,
-            k,
-            configs=platform_configs,
-            scale=0.5,
-            exclude=_is_large_block_for_cpu,
-        )
-    return filtered_configs(m, n, k, configs=platform_configs)
-
-
 LOOP_BODY_2D = """
     idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H
     idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W
@@ -497,6 +459,8 @@ def convolution(
         "groups": groups,
     }

+    device_type = ir.get_device_type(x)
+
     if len(x.get_size()) == len(weight.get_size()) - 1:
         # add batch dimension to simplify rest of function
         return L[aten.squeeze](
@@ -511,11 +475,7 @@
     # Always convert conv1D to 2D for Intel GPU.
     # Only conv2D can be converted to channel last layout,
     # which have much better performance.
-    if (
-        len(x.get_size()) == 3
-        and len(kernel_shape) == 1
-        and ir.get_device_type(x) == "xpu"
-    ):
+    if len(x.get_size()) == 3 and len(kernel_shape) == 1 and device_type == "xpu":
         kwargs.update(
             {
                 "stride": (1,) + stride,
@@ -564,7 +524,7 @@
     ):
         return convert_1x1_conv_to_mm(x, weight, bias)

-    if bias is not None and ir.get_device_type(x) != "cpu":
+    if bias is not None and device_type != "cpu":
         # peel off the bias, cudnn is slower with it
         result = convolution(x, weight, None, **kwargs)
         return L[aten.add](
@@ -639,11 +599,13 @@
     ):
         choices.append(aten_conv1x1_via_mm.bind(args, layout))

+    conv_configs = V.choices.get_conv_configs(device_type)
+
     for cfg in conv_configs(
         sympy_product([x.get_size()[0], *x.get_size()[2:]]),
         out_chan,
         in_chan,
-        device_type=ir.get_device_type(x),
+        **mm_config_kwargs(device_type, _is_large_block_for_cpu),
     ):
         if ndim == 2:
             conv2d_template.maybe_append_choice(
@@ -28,7 +28,6 @@ from ..select_algorithm import (
     TritonTemplate,
 )
 from ..utils import (
-    get_gpu_shared_memory,
     get_tma_workspace_arg,
     use_aten_gemm_kernels,
     use_ck_gemm_template,
@@ -41,17 +40,13 @@ from ..utils import (
 from .mm_common import (
     _is_static_problem,
     addmm_epilogue,
-    extra_mm_configs,
-    int8_mm_configs,
     mm_args,
-    mm_configs,
+    mm_config_kwargs,
     mm_grid,
     mm_options,
-    persistent_mm_configs,
     persistent_mm_grid,
     persistent_mm_options,
     should_fallback_to_aten,
-    triton_config,
 )

@@ -341,15 +336,6 @@ def _is_large_block_for_cpu(m, n, k):
     return m * n > 2**13


-def mm_config_kwargs(device):
-    if device == "cpu":
-        return {
-            "scale": 0.5,
-            "exclude": _is_large_block_for_cpu,
-        }
-    return {}
-
-
 def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1):
     """
     Giving torch.addmm a 1D tensor calls a different (faster) cublasLt
@@ -367,6 +353,7 @@ aten_bias_addmm = ExternKernelChoice(bias_addmm, None)
 @register_lowering(aten.mm, type_promotion_kind=None)
 def tuned_mm(mat1, mat2, *, layout=None):
     m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
+    device_type = ir.get_device_type(mat1)
     name = "mm"

     # below is for getting an overview logging info of inductor mms
@@ -392,8 +379,15 @@ def tuned_mm(mat1, mat2, *, layout=None):
         [aten_mm.bind((mat1, mat2), aten_layout)] if use_aten_gemm_kernels() else []
     )
     static_shape, is_nonzero = _is_static_problem(layout)
+
+    mm_configs = V.choices.get_base_mm_configs(device_type)
+    persistent_mm_configs = V.choices.get_persistent_mm_configs(device_type)
+    extra_mm_configs = V.choices.get_extra_mm_configs(device_type)
+
     if is_nonzero and use_triton_template(layout):
-        for config in mm_configs(m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))):
+        for config in mm_configs(
+            m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
+        ):
             mm_template.maybe_append_choice(
                 choices,
                 input_nodes=(mat1, mat2),
@@ -402,7 +396,7 @@
             )
     if use_triton_tma_template(mat1, mat2):
         for config in persistent_mm_configs(
-            m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))
+            m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
         ):
             persistent_tma_mm_template.maybe_append_choice(
                 choices,
@@ -441,7 +435,7 @@
         always_included.append("extern_mm")
         num_choices_before_extra_configs = len(choices)
         for config in extra_mm_configs(
-            m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))
+            m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
         ):
             mm_template.maybe_append_choice(
                 choices,
@@ -503,6 +497,8 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
         layout,
     )

+    device_type = ir.get_device_type(mat1)
+
     static_shape, is_nonzero = _is_static_problem(layout)
     use_cutlass = static_shape and is_nonzero and use_cutlass_template(layout, m, n, k)

@@ -514,9 +510,12 @@
         CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
             choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True
         )
+
+    int8_mm_configs = V.choices.get_int8_mm_configs(device_type)
+
     if is_nonzero and use_triton_template(layout, enable_int32=True):
         for config in int8_mm_configs(
-            m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))
+            m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
        ):
             mm_template.maybe_append_choice(
                 choices,
@@ -534,6 +533,7 @@
 @register_lowering(aten.addmm, type_promotion_kind=None)
 def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
     ordered_kwargs_for_cpp_kernel = ("beta", "alpha")
+    device_type = ir.get_device_type(mat1)
     m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout)
     static_shape, is_nonzero = _is_static_problem(layout)

@@ -599,8 +599,13 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
             ),
         )

+    mm_configs = V.choices.get_base_mm_configs(device_type)
+    persistent_mm_configs = V.choices.get_persistent_mm_configs(device_type)
+
     if is_nonzero and use_triton_template(layout):
-        for config in mm_configs(m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))):
+        for config in mm_configs(
+            m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
+        ):
             mm_template.maybe_append_choice(
                 choices,
                 input_nodes=(inp_expanded, mat1, mat2),
@@ -612,7 +617,7 @@

         if use_triton_tma_template(mat1, mat2):
             for config in persistent_mm_configs(
-                m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))
+                m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
             ):
                 persistent_tma_mm_template.maybe_append_choice(
                     choices,
@@ -751,52 +756,6 @@ def dims_are_int(dims):
     return all(isinstance(dim, int) for dim in dims)


-def try_heuristic(m, n, k, choices, mat1, mat2, mat2_dtype, layout):
-    m, n, k = get_size_hints(mat1, mat2, m, n, k)
-    if not dims_are_int([m, n, k]):
-        return None
-
-    if mat1.dtype != torch.float16:
-        return None
-
-    # only use heuristic if we are running on an A100
-    # torch.cuda.get_device_capability() >= (8, 0) returns true for A10G
-    # which does not have enough shared memory for one of the configs
-    if (
-        not torch.cuda.get_device_capability() >= (8, 0)
-    ) or get_gpu_shared_memory() != 166912:
-        return None
-
-    if m == 1 and (n % 16 != 0 or k % 16 != 0):
-        return None
-
-    if m <= 16 and n >= 4096 and k >= 4096:
-        return triton_config(
-            BLOCK_M=16,
-            BLOCK_N=64,
-            BLOCK_K=128,
-            num_stages=5,
-            num_warps=4,
-        )
-    elif m > 16 and m <= 32 and n >= 4096 and k >= 4096:
-        return triton_config(
-            BLOCK_M=32,
-            BLOCK_N=32,
-            BLOCK_K=128,
-            num_stages=5,
-            num_warps=4,
-        )
-    elif m > 32 and m <= 64 and n >= 4096 and k >= 4096:
-        return triton_config(
-            BLOCK_M=64,
-            BLOCK_N=32,
-            BLOCK_K=128,
-            num_stages=5,
-            num_warps=4,
-        )
-    return None
-
-
 def mm_autoheuristic(
     mat1,
     mat2,
@@ -1,441 +1,22 @@
 # mypy: allow-untyped-defs
-import functools
-import itertools
 import logging
-from collections.abc import Sequence
-from typing import Any, cast
+from typing import Any

 import sympy

 import torch
 from torch._inductor.select_algorithm import realize_inputs, SymbolicGridFn
 from torch._inductor.virtualized import V
-from torch.utils._ordered_set import OrderedSet

 from .. import config as inductor_config
 from ..codegen.wrapper import PythonWrapperCodegen
 from ..ir import ChoiceCaller, Layout
-from ..runtime.runtime_utils import next_power_of_2
-from ..utils import (
-    get_backend_num_stages,
-    get_num_sms,
-    TMA_DESCRIPTOR_SIZE,
-    use_aten_gemm_kernels,
-)
+from ..utils import get_num_sms, TMA_DESCRIPTOR_SIZE, use_aten_gemm_kernels

 log = logging.getLogger(__name__)

-def triton_config(num_stages, num_warps, **kwargs):
-    from triton import Config  # type: ignore[attr-defined]
-
-    return Config(kwargs, num_stages=num_stages, num_warps=num_warps)
-
-
-def build_rocm_gemm_configs(configs):
-    rocm_num_stages = get_backend_num_stages()
-    return tuple((c[0], c[1], c[2], rocm_num_stages, c[4]) for c in configs)
-
-
-def filtered_configs(
-    m: int,
-    n: int,
-    k: int,
-    configs: Sequence[tuple[int, int, int, int, int]],
-    has_int8_tensor=False,
-    scale=1,
-    exclude=lambda m, n, k: False,
-):
-    """
-    Heuristic to shrink configs when they are bigger than the input size
-
-    :param scale: scale factor applied to the config values
-    :param exclude: whether a given config should be excluded
-    """
-    from torch._inductor import config
-
-    max_mm_configs = config.test_configs.max_mm_configs
-
-    min_block_size = 16
-    # block_k=16 seems to be causing issues
-    # see: https://github.com/triton-lang/triton/issues/2156#issuecomment-1695897424
-    min_block_size_k = 32 if has_int8_tensor else 16
-    m = max(
-        next_power_of_2(
-            V.graph.sizevars.size_hint(
-                m,
-                fallback=torch._inductor.config.unbacked_symint_fallback,  # type: ignore[arg-type]
-            )
-        ),
-        min_block_size,
-    )
-    n = max(
-        next_power_of_2(
-            V.graph.sizevars.size_hint(
-                n,
-                fallback=torch._inductor.config.unbacked_symint_fallback,  # type: ignore[arg-type]
-            )
-        ),
-        min_block_size,
-    )
-    k = max(
-        next_power_of_2(
-            V.graph.sizevars.size_hint(
-                k,
-                fallback=torch._inductor.config.unbacked_symint_fallback,  # type: ignore[arg-type]
-            )
-        ),
-        min_block_size_k,
-    )
-    used = OrderedSet[tuple[int, ...]]()
-    for block_m, block_n, block_k, num_stages, num_warps in configs:
-        # shrink configs for small sizes
-        block_m = max(min(int(block_m * scale), m), min_block_size)
-        block_n = max(min(int(block_n * scale), n), min_block_size)
-        block_k = max(min(int(block_k * scale), k), min_block_size_k)
-
-        if exclude(block_m, block_n, block_k):
-            continue
-
-        # each warp computes 16x16 tile = 256
-        num_warps = min(num_warps, block_m * block_n // 256)
-        if torch.version.hip:
-            kpack = 2
-            for matrix_instr_nonkdim in [0, 16]:
-                if matrix_instr_nonkdim != 0 and (
-                    block_m % matrix_instr_nonkdim != 0
-                    or block_n % matrix_instr_nonkdim != 0
-                ):
-                    # block_m and block_n must be a multiple of matrix_instr_nonkdim
-                    continue
-
-                if (
-                    block_m,
-                    block_n,
-                    block_k,
-                    num_stages,
-                    num_warps,
-                    matrix_instr_nonkdim,
-                    kpack,
-                ) not in used and (
-                    max_mm_configs is None or len(used) < max_mm_configs
-                ):
-                    used.add(
-                        (
-                            block_m,
-                            block_n,
-                            block_k,
-                            num_stages,
-                            num_warps,
-                            matrix_instr_nonkdim,
-                            kpack,
-                        )
-                    )
-                    yield triton_config(
-                        BLOCK_M=block_m,
-                        BLOCK_N=block_n,
-                        BLOCK_K=block_k,
-                        num_stages=num_stages,
-                        num_warps=num_warps,
-                        matrix_instr_nonkdim=matrix_instr_nonkdim,
-                        kpack=kpack,
-                    )
-        else:
-            if (block_m, block_n, block_k, num_stages, num_warps, 0) not in used and (
-                max_mm_configs is None or len(used) < max_mm_configs
-            ):
-                used.add((block_m, block_n, block_k, num_stages, num_warps, 0))
-                yield triton_config(
-                    BLOCK_M=block_m,
-                    BLOCK_N=block_n,
-                    BLOCK_K=block_k,
-                    num_stages=num_stages,
-                    num_warps=num_warps,
-                )
-
-
-# List of dictionaries to store the kernel configs. Configs that evaluate to true
-# will be utilised on the target platform. The configs are as follows:
-# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
-mm_kernel_configs = (
-    [
-        {"config": (32, 32, 16, 1, 2), "cond": True},
-        {"config": (32, 32, 128, 2, 4), "cond": True},
-        {"config": (32, 64, 32, 5, 8), "cond": True},
-        {"config": (64, 32, 32, 5, 8), "cond": True},
-        {"config": (64, 32, 128, 5, 4), "cond": True},
-        {"config": (64, 64, 16, 2, 4), "cond": True},
-        {"config": (64, 64, 32, 2, 4), "cond": True},
-        {"config": (64, 64, 64, 3, 8), "cond": True},
-        {"config": (64, 64, 128, 5, 4), "cond": True},
-        {"config": (64, 128, 32, 3, 4), "cond": True},
-        {"config": (64, 128, 32, 4, 8), "cond": True},
-        {"config": (64, 128, 64, 3, 4), "cond": True},
-        {"config": (64, 128, 128, 4, 4), "cond": True},
-        {"config": (128, 64, 32, 3, 4), "cond": True},
-        {"config": (128, 64, 32, 4, 8), "cond": True},
-        {"config": (128, 128, 32, 2, 8), "cond": True},
-        {"config": (128, 128, 32, 3, 4), "cond": True},
-        {"config": (128, 128, 64, 3, 4), "cond": True},
-        {"config": (128, 128, 64, 5, 8), "cond": True},
-        {"config": (128, 256, 64, 3, 8), "cond": True},
-    ]
-    if inductor_config.max_autotune_gemm_search_space != "EXHAUSTIVE"
-    else [
-        {"config": (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps), "cond": True}
-        for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product(
-            [16, 32, 64, 128, 256], repeat=3
-        )
-        for num_stages in [1, 2, 3, 4, 5]
-        for num_warps in [2, 4, 8]
-    ]
-)
-
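For scale, the EXHAUSTIVE branch removed above (and reproduced later in BaseConfigHeuristic.exhaustive_configs) replaces the hand-picked list with a full cross-product; a quick count of that search space, shown as a small standalone snippet:

import itertools

exhaustive = [
    (bm, bn, bk, stages, warps)
    for bm, bn, bk in itertools.product([16, 32, 64, 128, 256], repeat=3)
    for stages in [1, 2, 3, 4, 5]
    for warps in [2, 4, 8]
]
print(len(exhaustive))  # 5**3 * 5 * 3 = 1875 candidates, versus 20 hand-picked defaults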
-# these are only used in tuned_mm when AutoHeuristic is enabled
-# the idea is that when AutoHeuristic collects data to learn a heuristic, more configs are autotuned
-# when the learned heuristic is used, the learned heuristic reduces the number of configs down to 10
-# which saves compilation time (since less configs are autotuned) and potentially increase performance
-# because the learned heuristic might predict a config that is not part mm_configs
-extra_mm_kernel_configs = [
-    {"config": (16, 32, 16, 3, 2), "cond": True},
-    {"config": (16, 32, 32, 4, 2), "cond": True},
-    {"config": (16, 32, 32, 5, 2), "cond": True},
-    {"config": (64, 64, 128, 3, 4), "cond": True},
-    {"config": (128, 64, 32, 2, 2), "cond": True},
-    {"config": (128, 64, 64, 3, 8), "cond": True},
-    {"config": (128, 64, 128, 4, 8), "cond": True},
-    {"config": (128, 128, 32, 4, 4), "cond": True},
-    {"config": (128, 128, 64, 3, 8), "cond": True},
-    {"config": (128, 128, 64, 5, 4), "cond": True},
-]
-
-int8_mm_kernel_configs = [
-    {"config": (64, 64, 32, 2, 4), "cond": True},
-    {"config": (64, 128, 32, 3, 4), "cond": True},
-    {"config": (128, 64, 32, 3, 4), "cond": True},
-    {"config": (64, 128, 32, 4, 8), "cond": True},
-    {"config": (128, 64, 32, 4, 8), "cond": True},
-    {"config": (64, 32, 32, 5, 8), "cond": True},
-    {"config": (32, 64, 32, 5, 8), "cond": True},
-    {"config": (128, 128, 32, 2, 8), "cond": True},
-    {"config": (64, 64, 64, 3, 8), "cond": True},
-    # {"config": (32, 32, 128, 2, 4), "cond": True},
-    # {"config": (64, 64, 16, 2, 4), "cond": True},
-    # {"config": (32, 32, 16, 1, 2), "cond": True},
-    {"config": (128, 256, 128, 3, 8), "cond": True},
-    {"config": (256, 128, 128, 3, 8), "cond": True},
-]
-
-# Mixed precision kernel configs for small sizes of m for mm's like (16, 8192) x (8192, 8192).
-mixed_mm_kernel_configs_small_m = [
-    {"config": (16, 128, 256, 3, 4), "cond": True},
-    {"config": (16, 128, 256, 5, 8), "cond": True},
-]
-
-mixed_mm_kernel_configs = (
-    mm_kernel_configs + mixed_mm_kernel_configs_small_m
-    if inductor_config.max_autotune_gemm_search_space != "EXHAUSTIVE"
-    else mm_kernel_configs
-)
-
-persistent_mm_kernel_configs = [
-    {"config": (128, 256, 64, 3, 8), "cond": True},
-    {"config": (128, 128, 64, 3, 8), "cond": True},
-    {"config": (128, 128, 128, 3, 8), "cond": True},
-    {"config": (128, 128, 128, 3, 4), "cond": True},
-    {"config": (128, 128, 64, 4, 8), "cond": True},
-]
-
-scaled_mm_kernel_configs = [
-    {"config": (128, 256, 32, 3, 8), "cond": True},
-    {"config": (256, 128, 32, 3, 8), "cond": True},
-    {"config": (256, 64, 32, 4, 4), "cond": True},
-    {"config": (64, 256, 32, 4, 4), "cond": True},
-    {"config": (128, 128, 32, 4, 4), "cond": True},
-    {"config": (128, 64, 32, 4, 4), "cond": True},
-    {"config": (64, 128, 32, 4, 4), "cond": True},
-    {"config": (128, 32, 32, 4, 4), "cond": True},
-    {"config": (64, 32, 32, 5, 2), "cond": True},
-    {"config": (256, 128, 128, 3, 8), "cond": True},
-    {"config": (256, 64, 128, 4, 4), "cond": True},
-    {"config": (64, 256, 128, 4, 4), "cond": True},
-    {"config": (128, 128, 128, 4, 4), "cond": True},
-    {"config": (128, 64, 64, 4, 4), "cond": True},
-    {"config": (64, 128, 64, 4, 4), "cond": True},
-    {"config": (128, 32, 64, 4, 4), "cond": True},
-    {"config": (64, 32, 64, 5, 2), "cond": True},
-    {"config": (16, 32, 32, 2, 2), "cond": True},
-    {"config": (16, 64, 32, 2, 2), "cond": True},
-    {"config": (16, 128, 32, 2, 4), "cond": True},
-    {"config": (16, 256, 32, 2, 4), "cond": True},
-    {"config": (16, 32, 64, 2, 2), "cond": True},
-    {"config": (16, 64, 64, 2, 2), "cond": True},
-    {"config": (16, 128, 64, 2, 4), "cond": True},
-    {"config": (16, 256, 64, 2, 4), "cond": True},
-    {"config": (32, 32, 32, 2, 2), "cond": True},
-    {"config": (32, 64, 32, 2, 2), "cond": True},
-    {"config": (32, 128, 32, 2, 4), "cond": True},
-    {"config": (32, 256, 32, 2, 4), "cond": True},
-    {"config": (32, 32, 64, 2, 2), "cond": True},
-    {"config": (32, 64, 64, 2, 2), "cond": True},
-    {"config": (32, 128, 64, 2, 4), "cond": True},
-    {"config": (32, 256, 64, 2, 4), "cond": True},
-    {"config": (16, 32, 32, 3, 2), "cond": True},
-    {"config": (16, 64, 32, 3, 2), "cond": True},
-    {"config": (16, 128, 32, 3, 4), "cond": True},
-    {"config": (16, 256, 32, 3, 4), "cond": True},
-    {"config": (16, 32, 64, 3, 2), "cond": True},
-    {"config": (16, 64, 64, 3, 2), "cond": True},
-    {"config": (16, 128, 64, 3, 4), "cond": True},
-    {"config": (16, 256, 64, 3, 4), "cond": True},
-    {"config": (32, 32, 32, 3, 2), "cond": True},
-    {"config": (32, 64, 32, 3, 2), "cond": True},
-    {"config": (32, 128, 32, 3, 4), "cond": True},
-    {"config": (32, 256, 32, 3, 4), "cond": True},
-    {"config": (32, 32, 64, 3, 2), "cond": True},
-    {"config": (32, 64, 64, 3, 2), "cond": True},
-    {"config": (32, 128, 64, 3, 4), "cond": True},
-    {"config": (32, 256, 64, 3, 4), "cond": True},
-    {"config": (16, 32, 32, 4, 2), "cond": True},
-    {"config": (16, 64, 32, 4, 2), "cond": True},
-    {"config": (16, 128, 32, 4, 4), "cond": True},
-    {"config": (16, 256, 32, 4, 4), "cond": True},
-    {"config": (16, 32, 64, 4, 2), "cond": True},
-    {"config": (16, 64, 64, 4, 2), "cond": True},
-    {"config": (16, 128, 64, 4, 4), "cond": True},
-    {"config": (16, 256, 64, 4, 4), "cond": True},
-    {"config": (32, 32, 32, 4, 2), "cond": True},
-    {"config": (32, 64, 32, 4, 2), "cond": True},
-    {"config": (32, 128, 32, 4, 4), "cond": True},
-    {"config": (32, 256, 32, 4, 4), "cond": True},
-    {"config": (32, 32, 64, 4, 2), "cond": True},
-    {"config": (32, 64, 64, 4, 2), "cond": True},
-    {"config": (32, 128, 64, 4, 4), "cond": True},
-    {"config": (32, 256, 64, 4, 4), "cond": True},
-    {"config": (16, 32, 32, 5, 2), "cond": True},
-    {"config": (16, 64, 32, 5, 2), "cond": True},
-    {"config": (16, 128, 32, 5, 4), "cond": True},
-    {"config": (16, 256, 32, 5, 4), "cond": True},
-    {"config": (16, 32, 64, 5, 2), "cond": True},
-    {"config": (16, 64, 64, 5, 2), "cond": True},
-    {"config": (16, 128, 64, 5, 4), "cond": True},
-    {"config": (16, 256, 64, 5, 4), "cond": True},
-    {"config": (32, 32, 32, 5, 2), "cond": True},
-    {"config": (32, 64, 32, 5, 2), "cond": True},
-    {"config": (32, 128, 32, 5, 4), "cond": True},
-    {"config": (32, 256, 32, 5, 4), "cond": True},
-    {"config": (32, 32, 64, 5, 2), "cond": True},
-    {"config": (32, 64, 64, 5, 2), "cond": True},
-    {"config": (32, 128, 64, 5, 4), "cond": True},
-    {"config": (32, 256, 64, 5, 4), "cond": True},
-    {"config": (16, 32, 32, 6, 2), "cond": True},
-    {"config": (16, 64, 32, 6, 2), "cond": True},
-    {"config": (16, 128, 32, 6, 4), "cond": True},
-    {"config": (16, 256, 32, 6, 4), "cond": True},
-    {"config": (16, 32, 64, 6, 2), "cond": True},
-    {"config": (16, 64, 64, 6, 2), "cond": True},
-    {"config": (16, 128, 64, 6, 4), "cond": True},
-    {"config": (16, 256, 64, 6, 4), "cond": True},
-    {"config": (32, 32, 32, 6, 2), "cond": True},
-    {"config": (32, 64, 32, 6, 2), "cond": True},
-    {"config": (32, 128, 32, 6, 4), "cond": True},
-    {"config": (32, 256, 32, 6, 4), "cond": True},
-    {"config": (32, 32, 64, 6, 2), "cond": True},
-    {"config": (32, 64, 64, 6, 2), "cond": True},
-    {"config": (32, 128, 64, 6, 4), "cond": True},
-    {"config": (32, 256, 64, 6, 4), "cond": True},
-]
-
-scaled_persistent_mm_kernel_configs = [
-    {"config": (128, 128, 64, 3, 8), "cond": True},
-    {"config": (128, 128, 128, 3, 8), "cond": True},
-    {"config": (128, 128, 128, 4, 8), "cond": True},
-    {"config": (128, 128, 128, 4, 4), "cond": True},
-    {"config": (128, 128, 128, 3, 4), "cond": True},
-    {"config": (128, 128, 128, 5, 4), "cond": True},
-    {"config": (128, 128, 128, 5, 8), "cond": True},
-    {"config": (128, 128, 128, 6, 8), "cond": True},
-    {"config": (128, 128, 64, 4, 8), "cond": True},
-]
-
-
-# Create filtered list of configs based on cond evaluation
-mm_platform_configs = tuple(
-    cast(tuple[int, int, int, int, int], config["config"])
-    for config in mm_kernel_configs
-    if config["cond"]
-)
-extra_mm_platform_configs = tuple(
-    cast(tuple[int, int, int, int, int], config["config"])
-    for config in extra_mm_kernel_configs
-    if config["cond"]
-)
-int8_platform_configs = tuple(
-    cast(tuple[int, int, int, int, int], config["config"])
-    for config in int8_mm_kernel_configs
-    if config["cond"]
-)
-mixed_mm_platform_configs = tuple(
-    cast(tuple[int, int, int, int, int], config["config"])
-    for config in mixed_mm_kernel_configs
-    if config["cond"]
-)
-persistent_mm_platform_configs = tuple(
-    cast(tuple[int, int, int, int, int], config["config"])
-    for config in persistent_mm_kernel_configs
-    if config["cond"]
-)
-scaled_mm_platform_configs = tuple(
-    cast(tuple[int, int, int, int, int], config["config"])
-    for config in scaled_mm_kernel_configs
-    if config["cond"]
-)
-scaled_persistent_mm_platform_configs = tuple(
-    cast(tuple[int, int, int, int, int], config["config"])
-    for config in scaled_persistent_mm_kernel_configs
-    if config["cond"]
-)
-
-# On ROCm convert num_stages to improve performance
-if torch.version.hip and torch.cuda.is_available():
-    mm_platform_configs = build_rocm_gemm_configs(mm_platform_configs)
-    extra_mm_platform_configs = build_rocm_gemm_configs(extra_mm_platform_configs)
-    int8_platform_configs = build_rocm_gemm_configs(int8_platform_configs)
-    mixed_mm_platform_configs = build_rocm_gemm_configs(mixed_mm_platform_configs)
-    scaled_mm_platform_configs = build_rocm_gemm_configs(scaled_mm_platform_configs)
-
-mm_configs = functools.partial(
-    filtered_configs,
-    configs=mm_platform_configs,
-)
-
-extra_mm_configs = functools.partial(
-    filtered_configs,
-    configs=extra_mm_platform_configs,
-)
-
-int8_mm_configs = functools.partial(
-    filtered_configs,
-    configs=int8_platform_configs,
-)
-
-persistent_mm_configs = functools.partial(
-    filtered_configs,
-    configs=persistent_mm_platform_configs,
-)
-
-scaled_mm_configs = functools.partial(
-    filtered_configs,
-    configs=scaled_mm_platform_configs,
-)
-
-scaled_persistent_mm_configs = functools.partial(
-    filtered_configs,
-    configs=scaled_persistent_mm_platform_configs,
-)
-
-
 def should_fallback_to_aten(choices: list[ChoiceCaller]) -> bool:
     if len(choices) == 0 and not use_aten_gemm_kernels():
         if inductor_config.autotune_fallback_to_aten:
@@ -553,6 +134,15 @@ def mm_args(
     return [m, n, k, layout, mat1, mat2, *others]


+def mm_config_kwargs(device, exclude_condition):
+    if device == "cpu":
+        return {
+            "scale": 0.5,
+            "exclude": exclude_condition,
+        }
+    return {}
+
+
 def addmm_epilogue(dtype, alpha, beta):
    def epilogue(acc, bias):
        if alpha != 1:
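The scale and exclude kwargs returned by mm_config_kwargs feed the block-shrinking rule that previously lived in filtered_configs (removed above) and now lives in the heuristic classes. A small self-contained worked example, with the power-of-two helper re-implemented locally for illustration:

# Worked example of the block-shrinking rule: block sizes are clamped to the
# (power-of-two padded) problem size, with a floor of 16 (32 for BLOCK_K on int8 inputs).
def next_power_of_2(x: int) -> int:
    return 1 << (x - 1).bit_length()


def shrink(block: int, problem: int, scale: float = 1.0, floor: int = 16) -> int:
    return max(min(int(block * scale), problem), floor)


m = n = k = next_power_of_2(48)  # 48 is padded up to 64
for bm, bn, bk, stages, warps in [(128, 128, 32, 3, 4), (64, 128, 64, 3, 4)]:
    print(shrink(bm, m), shrink(bn, n), shrink(bk, k), stages, warps)
# -> 64 64 32 3 4  and  64 64 64 3 4: tiles never exceed the padded problem size.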
@@ -1,8 +1,8 @@
 # mypy: allow-untyped-defs
-import functools

 import torch

+from .. import ir
 from ..lowering import lowerings
 from ..select_algorithm import (
     autotune_select_algorithm,
@@ -112,101 +112,14 @@ mm_plus_mm_template = TritonTemplate(
 )


-@functools.lru_cache(None)
-def mm_configs():
-    import triton
-
-    # List of dictionaries to store the kernel configs. Configs that evaluate to true
-    # will be utilised on the target platform
-    mm_triton_configs = [
-        {
-            "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32},
-            "num_stages": 2,
-            "num_warps": 4,
-            "cond": True,
-        },
-        {
-            "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32},
-            "num_stages": 3,
-            "num_warps": 8,
-            "cond": True,
-        },
-        {
-            "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32},
-            "num_stages": 4,
-            "num_warps": 16,
-            "cond": True,
-        },
-        {
-            "config": {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32},
-            "num_stages": 4,
-            "num_warps": 8,
-            "cond": True,
-        },
-        {
-            "config": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32},
-            "num_stages": 4,
-            "num_warps": 8,
-            "cond": True,
-        },
-        {
-            "config": {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32},
-            "num_stages": 1,
-            "num_warps": 8,
-            "cond": True,
-        },
-        {
-            "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64},
-            "num_stages": 1,
-            "num_warps": 8,
-            "cond": True,
-        },
-        {
-            "config": {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 128},
-            "num_stages": 1,
-            "num_warps": 8,
-            "cond": torch.version.hip is None,
-        },
-        {
-            "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 16},
-            "num_stages": 2,
-            "num_warps": 4,
-            "cond": True,
-        },
-        {
-            "config": {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 16},
-            "num_stages": 1,
-            "num_warps": 2,
-            "cond": True,
-        },
-    ]
-
-    # Filter out configs in which cond evaluates to true
-    # On ROCm convert num_stages to 1 as pipelining provides no benefit
-    if torch.version.hip:
-        filtered_configs = [
-            triton.Config(c["config"], num_stages=1, num_warps=c["num_warps"])
-            for c in mm_triton_configs
-            if c["cond"]
-        ]
-    else:
-        filtered_configs = [
-            triton.Config(
-                c["config"], num_stages=c["num_stages"], num_warps=c["num_warps"]
-            )
-            for c in mm_triton_configs
-            if c["cond"]
-        ]
-
-    return filtered_configs
-
-
 def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
     """
     Computes mm(mat1, mat2) + mm(mat3, mat4)
     """
     m1, n1, k1, layout1, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
     m2, n2, _, layout2, mat3, mat4 = mm_args(mat3, mat4, layout=layout)
+    device_type = ir.get_device_type(mat1)

     # Optimization is optional, because we can always just not do the fusion
     if (
         m1 * n1 == 0
@@ -231,6 +144,9 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
         if use_aten_gemm_kernels()
         else []
     )

+    mm_configs = V.choices.get_mm_plus_mm_configs(device_type)
+
     if use_triton_template(layout1):
         for config in mm_configs():
             # see https://github.com/openai/triton/issues/1298
@@ -11,7 +11,7 @@ from torch._inductor.codegen.rocm.ck_universal_gemm_template import CKGemmTempla
 from torch.utils._triton import has_triton_tma_device

 from ..config import triton as triton_config
-from ..ir import _IntLike, ChoiceCaller, Layout, StorageBox, TensorBox
+from ..ir import _IntLike, ChoiceCaller, get_device_type, Layout, StorageBox, TensorBox
 from ..lowering import add_layout_constraint, constrain_to_fx_strides, register_lowering
 from ..select_algorithm import (
     autotune_select_algorithm,
@@ -27,13 +27,12 @@ from ..utils import (
     use_ck_gemm_template,
     use_triton_template,
 )
+from ..virtualized import V
 from .mm_common import (
     _is_static_problem,
     mm_args,
     mm_grid,
     persistent_mm_grid,
-    scaled_mm_configs,
-    scaled_persistent_mm_configs,
     should_fallback_to_aten,
 )

@@ -508,6 +507,7 @@ def tuned_scaled_mm(
     m, n, k, layout, mat_a, mat_b = mm_args(
         mat_a, mat_b, layout=layout, out_dtype=out_dtype
     )
+
     # below is for getting an overview logging info of inductor mms
     counters["aten_mm_info"][f"aten._scaled_mm.default_{m}_{n}_{k}"] += 1
     log.info(
@@ -520,6 +520,8 @@
         layout,
     )

+    device_type = get_device_type(mat_a)
+
     check_supported_striding(mat_a, mat_b)

     scale_a, scale_b = realize_inputs(scale_a, scale_b)
@@ -544,6 +546,11 @@
     _, is_nonzero = _is_static_problem(layout)

+    scaled_mm_configs = V.choices.get_scaled_mm_configs(device_type)
+    scaled_persistent_mm_configs = V.choices.get_scaled_persistent_mm_configs(
+        device_type
+    )
+
     if is_nonzero and use_triton_template(layout, enable_float8=True):
         if use_persistent_tma(k, bias is not None):
             for config in scaled_persistent_mm_configs(m, n, k):
@@ -44,7 +44,8 @@ def set_driver_to_gpu():


 def get_backend_options():
-    driver = triton.runtime.driver
+    from triton.runtime import driver
+
     target = driver.active.get_current_target()
     backend = triton.compiler.compiler.make_backend(target)
     options = backend.parse_options(dict())
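For context, the hunk above only swaps the module-attribute lookup for a local import; the backend-options query itself is unchanged. A hedged sketch of that query, built only from the calls shown above, which needs a Triton installation with an active GPU backend to return anything:

# Sketch based on the calls in the hunk above; guarded so it degrades gracefully.
def get_backend_options():
    from triton.runtime import driver  # local import, as in the new code
    import triton

    target = driver.active.get_current_target()
    backend = triton.compiler.compiler.make_backend(target)
    return backend.parse_options(dict())


if __name__ == "__main__":
    try:
        print(get_backend_options())
    except Exception as exc:  # no GPU or no Triton available
        print(f"backend options unavailable: {exc}")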
torch/_inductor/template_heuristics.py (new file, 571 lines)

@@ -0,0 +1,571 @@
+from __future__ import annotations
+
+import itertools
+from collections import namedtuple
+from functools import partial
+from threading import Lock
+from typing import Any, Callable, TYPE_CHECKING
+
+from torch.utils._ordered_set import OrderedSet
+
+from . import config
+from .utils import get_backend_num_stages
+from .virtualized import V
+
+
+if TYPE_CHECKING:
+    from collections.abc import Generator, Sequence
+
+    from triton import Config as TritonConfig
+
+
+class BaseConfigSingleton(type):
+    """
+    Thread-safe implementation of single to be used in the config heuristic subclasses
+    to ensure heavy __init__ calls are not repeatedly run
+    """
+
+    _instances: dict[type[Any], Any] = {}
+    _lock: Lock = Lock()
+
+    def __call__(
+        cls: BaseConfigSingleton, *args: Any, **kwargs: Any
+    ) -> BaseConfigHeuristic:
+        with cls._lock:
+            if cls not in cls._instances:
+                instance = super().__call__()
+                cls._instances[cls] = instance
+            return cls._instances[cls]
+
+
+Config = namedtuple(
+    "Config", ["block_m", "block_n", "block_k", "num_stages", "num_warps"]
+)
+
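The metaclass above caches one instance per heuristic class so the large config lists built in __init__ are constructed once, even across threads. A self-contained demo of that pattern (names here are local to this sketch, not the inductor classes):

import threading


class SingletonMeta(type):
    _instances: dict[type, object] = {}
    _lock = threading.Lock()

    def __call__(cls, *args, **kwargs):
        # Creation happens under a lock so a heavy __init__ runs only once,
        # even when multiple compile threads race to construct the heuristic.
        with cls._lock:
            if cls not in cls._instances:
                cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]


class ExpensiveHeuristic(metaclass=SingletonMeta):
    def __init__(self):
        # Imagine building hundreds of candidate configs here.
        self.configs = [(64, 64, 32, 2, 4)] * 100


a, b = ExpensiveHeuristic(), ExpensiveHeuristic()
assert a is b  # __init__ ran only once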
||||||
|
class BaseConfigHeuristic(metaclass=BaseConfigSingleton):
    """
    Base class for mm_configs; device-specific triton kernel configs inherit from here.
    """

    def __init__(self) -> None:
        # List of namedtuples that store the kernel configs. Configs that evaluate to
        # true will be utilised on the target platform. The configs are as follows:
        # (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
        self.mm_configs = [
            Config(32, 32, 16, 1, 2),
            Config(32, 32, 128, 2, 4),
            Config(32, 64, 32, 5, 8),
            Config(64, 32, 32, 5, 8),
            Config(64, 32, 128, 5, 4),
            Config(64, 64, 16, 2, 4),
            Config(64, 64, 32, 2, 4),
            Config(64, 64, 64, 3, 8),
            Config(64, 64, 128, 5, 4),
            Config(64, 128, 32, 3, 4),
            Config(64, 128, 32, 4, 8),
            Config(64, 128, 64, 3, 4),
            Config(64, 128, 128, 4, 4),
            Config(128, 64, 32, 3, 4),
            Config(128, 64, 32, 4, 8),
            Config(128, 128, 32, 2, 8),
            Config(128, 128, 32, 3, 4),
            Config(128, 128, 64, 3, 4),
            Config(128, 128, 64, 5, 8),
        ]

        # Exhaustive search for mm configs
        self.exhaustive_configs = [
            Config(BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
            for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product(
                [16, 32, 64, 128, 256], repeat=3
            )
            for num_stages in [1, 2, 3, 4, 5]
            for num_warps in [2, 4, 8]
        ]

        # These are only used in tuned_mm when AutoHeuristic is enabled.
        # The idea is that when AutoHeuristic collects data to learn a heuristic,
        # more configs are autotuned. When the learned heuristic is used, it reduces
        # the number of configs down to 10, which saves compilation time (since fewer
        # configs are autotuned) and potentially increases performance, because the
        # learned heuristic might predict a config that is not part of mm_configs.
        self.extra_mm_configs = [
            Config(16, 32, 16, 3, 2),
            Config(16, 32, 32, 4, 2),
            Config(16, 32, 32, 5, 2),
            Config(64, 64, 128, 3, 4),
            Config(128, 64, 32, 2, 2),
            Config(128, 64, 64, 3, 8),
            Config(128, 64, 128, 4, 8),
            Config(128, 128, 32, 4, 4),
            Config(128, 128, 64, 3, 8),
            Config(128, 128, 64, 5, 4),
        ]

        self.int8_mm_configs = [
            Config(64, 64, 32, 2, 4),
            Config(64, 128, 32, 3, 4),
            Config(128, 64, 32, 3, 4),
            Config(64, 128, 32, 4, 8),
            Config(128, 64, 32, 4, 8),
            Config(64, 32, 32, 5, 8),
            Config(32, 64, 32, 5, 8),
            Config(128, 128, 32, 2, 8),
            Config(64, 64, 64, 3, 8),
            Config(128, 256, 128, 3, 8),
            Config(256, 128, 128, 3, 8),
        ]

        self.mixed_mm_configs = [
            Config(16, 128, 256, 3, 4),
            Config(16, 128, 256, 5, 8),
        ]

        self.persistent_mm_configs = [
            Config(128, 256, 64, 3, 8),
            Config(128, 128, 64, 3, 8),
            Config(128, 128, 128, 3, 8),
            Config(128, 128, 128, 3, 4),
            Config(128, 128, 64, 4, 8),
        ]

        self.scaled_mm_configs = [
            Config(128, 256, 32, 3, 8),
            Config(256, 128, 32, 3, 8),
            Config(256, 64, 32, 4, 4),
            Config(64, 256, 32, 4, 4),
            Config(128, 128, 32, 4, 4),
            Config(128, 64, 32, 4, 4),
            Config(64, 128, 32, 4, 4),
            Config(128, 32, 32, 4, 4),
            Config(64, 32, 32, 5, 2),
            Config(256, 128, 128, 3, 8),
            Config(256, 64, 128, 4, 4),
            Config(64, 256, 128, 4, 4),
            Config(128, 128, 128, 4, 4),
            Config(128, 64, 64, 4, 4),
            Config(64, 128, 64, 4, 4),
            Config(128, 32, 64, 4, 4),
            Config(64, 32, 64, 5, 2),
            Config(16, 32, 32, 2, 2),
            Config(16, 64, 32, 2, 2),
            Config(16, 128, 32, 2, 4),
            Config(16, 256, 32, 2, 4),
            Config(16, 32, 64, 2, 2),
            Config(16, 64, 64, 2, 2),
            Config(16, 128, 64, 2, 4),
            Config(16, 256, 64, 2, 4),
            Config(32, 32, 32, 2, 2),
            Config(32, 64, 32, 2, 2),
            Config(32, 128, 32, 2, 4),
            Config(32, 256, 32, 2, 4),
            Config(32, 32, 64, 2, 2),
            Config(32, 64, 64, 2, 2),
            Config(32, 128, 64, 2, 4),
            Config(32, 256, 64, 2, 4),
            Config(16, 32, 32, 3, 2),
            Config(16, 64, 32, 3, 2),
            Config(16, 128, 32, 3, 4),
            Config(16, 256, 32, 3, 4),
            Config(16, 32, 64, 3, 2),
            Config(16, 64, 64, 3, 2),
            Config(16, 128, 64, 3, 4),
            Config(16, 256, 64, 3, 4),
            Config(32, 32, 32, 3, 2),
            Config(32, 64, 32, 3, 2),
            Config(32, 128, 32, 3, 4),
            Config(32, 256, 32, 3, 4),
            Config(32, 32, 64, 3, 2),
            Config(32, 64, 64, 3, 2),
            Config(32, 128, 64, 3, 4),
            Config(32, 256, 64, 3, 4),
            Config(16, 32, 32, 4, 2),
            Config(16, 64, 32, 4, 2),
            Config(16, 128, 32, 4, 4),
            Config(16, 256, 32, 4, 4),
            Config(16, 32, 64, 4, 2),
            Config(16, 64, 64, 4, 2),
            Config(16, 128, 64, 4, 4),
            Config(16, 256, 64, 4, 4),
            Config(32, 32, 32, 4, 2),
            Config(32, 64, 32, 4, 2),
            Config(32, 128, 32, 4, 4),
            Config(32, 256, 32, 4, 4),
            Config(32, 32, 64, 4, 2),
            Config(32, 64, 64, 4, 2),
            Config(32, 128, 64, 4, 4),
            Config(32, 256, 64, 4, 4),
            Config(16, 32, 32, 5, 2),
            Config(16, 64, 32, 5, 2),
            Config(16, 128, 32, 5, 4),
            Config(16, 256, 32, 5, 4),
            Config(16, 32, 64, 5, 2),
            Config(16, 64, 64, 5, 2),
            Config(16, 128, 64, 5, 4),
            Config(16, 256, 64, 5, 4),
            Config(32, 32, 32, 5, 2),
            Config(32, 64, 32, 5, 2),
            Config(32, 128, 32, 5, 4),
            Config(32, 256, 32, 5, 4),
            Config(32, 32, 64, 5, 2),
            Config(32, 64, 64, 5, 2),
            Config(32, 128, 64, 5, 4),
            Config(32, 256, 64, 5, 4),
            Config(16, 32, 32, 6, 2),
            Config(16, 64, 32, 6, 2),
            Config(16, 128, 32, 6, 4),
            Config(16, 256, 32, 6, 4),
            Config(16, 32, 64, 6, 2),
            Config(16, 64, 64, 6, 2),
            Config(16, 128, 64, 6, 4),
            Config(16, 256, 64, 6, 4),
            Config(32, 32, 32, 6, 2),
            Config(32, 64, 32, 6, 2),
            Config(32, 128, 32, 6, 4),
            Config(32, 256, 32, 6, 4),
            Config(32, 32, 64, 6, 2),
            Config(32, 64, 64, 6, 2),
            Config(32, 128, 64, 6, 4),
            Config(32, 256, 64, 6, 4),
        ]

        self.scaled_persistent_mm_configs = [
            Config(128, 128, 64, 3, 8),
            Config(128, 128, 128, 3, 8),
            Config(128, 128, 128, 4, 8),
            Config(128, 128, 128, 4, 4),
            Config(128, 128, 128, 3, 4),
            Config(128, 128, 128, 5, 4),
            Config(128, 128, 128, 5, 8),
            Config(128, 128, 128, 6, 8),
            Config(128, 128, 64, 4, 8),
        ]

        # TODO: Unify with other gemm patterns; mm_plus_mm currently follows a
        # slightly different pattern than the rest.
        self.mm_plus_mm_configs = [
            Config(64, 64, 32, 2, 4),
            Config(64, 64, 32, 3, 8),
            Config(64, 64, 32, 4, 16),
            Config(64, 32, 32, 4, 8),
            Config(32, 64, 32, 4, 8),
            Config(128, 128, 32, 1, 8),
            Config(64, 64, 64, 1, 8),
            Config(32, 32, 128, 1, 8),
            Config(64, 64, 16, 2, 4),
            Config(32, 32, 16, 1, 2),
        ]

        self.conv_configs = [
            Config(64, 256, 16, 2, 4),
            Config(256, 64, 16, 2, 4),
            Config(1024, 16, 16, 1, 8),
            Config(128, 128, 32, 2, 8),
            Config(64, 64, 32, 2, 4),
            Config(64, 256, 32, 2, 8),
            Config(256, 64, 32, 2, 8),
        ]

    def _finalize_mm_configs(
        self,
        configs: list[Config],
    ) -> Generator[TritonConfig, None, None]:
        """
        Finalizes configs after scaling, applying additional constraints.
        """
        used = OrderedSet[Config]()

        max_mm_configs = config.test_configs.max_mm_configs

        for block_m, block_n, block_k, num_stages, num_warps in configs:
            # Each warp computes a 16x16 tile = 256 elements
            num_warps = min(num_warps, block_m * block_n // 256)

            if (
                Config(block_m, block_n, block_k, num_stages, num_warps)
            ) not in used and (max_mm_configs is None or len(used) < max_mm_configs):
                used.add(Config(block_m, block_n, block_k, num_stages, num_warps))
                yield self.triton_config(
                    BLOCK_M=block_m,
                    BLOCK_N=block_n,
                    BLOCK_K=block_k,
                    num_stages=num_stages,
                    num_warps=num_warps,
                )
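
    # Illustration (not part of this file): the num_warps clamp in
    # _finalize_mm_configs means Config(32, 32, 16, 1, 2) keeps num_warps=2
    # (32 * 32 // 256 == 4 >= 2), while a hypothetical Config(16, 16, 32, 2, 4)
    # would be reduced to num_warps=1 (16 * 16 // 256 == 1). Duplicate tuples
    # produced by scaling are emitted only once via the `used` set.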
    def _scale_mm_configs(
        self,
        m: int,
        n: int,
        k: int,
        configs: Sequence[Config],
        scale: float,
        has_int8_tensor: bool,
        exclude: Callable[[int, int, int], bool],
    ) -> list[Config]:
        """
        Scales and filters matrix multiplication configs based on input size.
        """
        from .runtime.runtime_utils import next_power_of_2

        min_block_size = 16
        min_block_size_k = 32 if has_int8_tensor else 16

        m = max(
            next_power_of_2(
                V.graph.sizevars.size_hint(
                    m,
                    fallback=config.unbacked_symint_fallback,  # type: ignore[arg-type]
                )
            ),
            min_block_size,
        )
        n = max(
            next_power_of_2(
                V.graph.sizevars.size_hint(
                    n,
                    fallback=config.unbacked_symint_fallback,  # type: ignore[arg-type]
                )
            ),
            min_block_size,
        )
        k = max(
            next_power_of_2(
                V.graph.sizevars.size_hint(
                    k,
                    fallback=config.unbacked_symint_fallback,  # type: ignore[arg-type]
                )
            ),
            min_block_size_k,
        )

        scaled_configs = []
        for c in configs:
            scaled_config = c._replace(
                block_m=max(min(int(c.block_m * scale), m), min_block_size),
                block_n=max(min(int(c.block_n * scale), n), min_block_size),
                block_k=max(min(int(c.block_k * scale), k), min_block_size_k),
            )

            if not exclude(
                scaled_config.block_m, scaled_config.block_n, scaled_config.block_k
            ):
                scaled_configs.append(scaled_config)

        return scaled_configs
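
    # Worked example (illustration only, not part of this file): for a gemm with
    # m=7, n=9, k=100 and scale=1, the size hints round up to m=16, n=16, k=128
    # (next power of two, floored at the 16/16 minimum block sizes), so
    # Config(128, 128, 32, 3, 4) is clamped to Config(16, 16, 32, 3, 4) before
    # being handed to _finalize_mm_configs.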
    def preprocess_mm_configs(
        self,
        m: int,
        n: int,
        k: int,
        configs: Sequence[Config],
        has_int8_tensor: bool = False,
        scale: int = 1,
        exclude: Callable[[int, int, int], bool] = lambda m, n, k: False,
    ) -> Generator[TritonConfig, None, None]:
        scaled_configs = self._scale_mm_configs(
            m, n, k, configs, scale, has_int8_tensor, exclude
        )
        return self._finalize_mm_configs(scaled_configs)

    def triton_config(
        self, num_stages: int, num_warps: int, **kwargs: Any
    ) -> TritonConfig:
        from triton import Config as TritonConfig  # type: ignore[attr-defined]

        return TritonConfig(kwargs, num_stages=num_stages, num_warps=num_warps)

    def get_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        return partial(self.preprocess_mm_configs, configs=self.mm_configs)

    def get_exhaustive_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        return partial(self.preprocess_mm_configs, configs=self.exhaustive_configs)

    def get_extra_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        return partial(self.preprocess_mm_configs, configs=self.extra_mm_configs)

    def get_int8_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        return partial(self.preprocess_mm_configs, configs=self.int8_mm_configs)

    def get_mixed_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        mm_configs = (
            self.mm_configs + self.mixed_mm_configs
            if config.max_autotune_gemm_search_space == "EXHAUSTIVE"
            else self.mm_configs
        )
        return partial(self.preprocess_mm_configs, configs=mm_configs)

    def get_persistent_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        return partial(self.preprocess_mm_configs, configs=self.persistent_mm_configs)

    def get_scaled_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        return partial(self.preprocess_mm_configs, configs=self.scaled_mm_configs)

    def get_scaled_persistent_mm_configs(
        self,
    ) -> partial[Generator[TritonConfig, None, None]]:
        return partial(
            self.preprocess_mm_configs, configs=self.scaled_persistent_mm_configs
        )

    def get_mm_plus_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        return partial(self._finalize_mm_configs, configs=self.mm_plus_mm_configs)

    def get_conv_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        return partial(self.preprocess_mm_configs, configs=self.conv_configs)
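
The getters above return functools.partial objects over preprocess_mm_configs, so callers only supply the problem shape. A rough usage sketch (illustration only; it assumes it runs during an Inductor lowering, where V.graph is populated so size hints can be resolved, and the shapes are made up):

    # Sketch only, not part of the diff.
    heuristic = BaseConfigHeuristic()
    gen = heuristic.get_mm_configs()          # partial over preprocess_mm_configs
    for cfg in gen(m=512, n=512, k=256):      # yields triton.Config objects
        print(cfg.kwargs["BLOCK_M"], cfg.kwargs["BLOCK_N"], cfg.kwargs["BLOCK_K"],
              cfg.num_stages, cfg.num_warps)
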

class CPUConfigHeuristic(BaseConfigHeuristic):
    pass


class CUDAConfigHeuristic(BaseConfigHeuristic):
    pass


class ROCmConfigHeuristic(BaseConfigHeuristic):
    def __init__(self) -> None:
        super().__init__()

        self.default_num_stages = get_backend_num_stages()

        # Exhaustive search for mm configs
        self.exhaustive_configs = [
            Config(BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
            for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product(
                [16, 32, 64, 128, 256], repeat=3
            )
            for num_stages in [1, self.default_num_stages]
            for num_warps in [4, 8]
        ]

    def _filter_configs(
        self, configs: list[Config], new_num_stages: int
    ) -> list[Config]:
        filtered_configs = [
            c._replace(num_stages=self.default_num_stages) for c in configs
        ]
        return filtered_configs

    def _finalize_mm_configs(
        self,
        configs: list[Config],
    ) -> Generator[TritonConfig, None, None]:
        used = OrderedSet[tuple[Config, int, int]]()

        max_mm_configs = config.test_configs.max_mm_configs
        for block_m, block_n, block_k, num_stages, num_warps in configs:
            # each warp computes a 16x16 tile = 256 elements
            num_warps = min(num_warps, block_m * block_n // 256)
            kpack = 2
            for matrix_instr_nonkdim in [0, 16]:
                if matrix_instr_nonkdim != 0 and (
                    block_m % matrix_instr_nonkdim != 0
                    or block_n % matrix_instr_nonkdim != 0
                ):
                    # block_m and block_n must be a multiple of matrix_instr_nonkdim
                    continue
                if (
                    Config(
                        block_m,
                        block_n,
                        block_k,
                        num_stages,
                        num_warps,
                    ),
                    matrix_instr_nonkdim,
                    kpack,
                ) not in used and (
                    max_mm_configs is None or len(used) < max_mm_configs
                ):
                    used.add(
                        (
                            Config(
                                block_m,
                                block_n,
                                block_k,
                                num_stages,
                                num_warps,
                            ),
                            matrix_instr_nonkdim,
                            kpack,
                        )
                    )

                    yield self.triton_config(
                        BLOCK_M=block_m,
                        BLOCK_N=block_n,
                        BLOCK_K=block_k,
                        num_stages=num_stages,
                        num_warps=num_warps,
                        matrix_instr_nonkdim=matrix_instr_nonkdim,
                        kpack=kpack,
                    )
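
    # Note (illustration, not part of this file): unlike the base class, the
    # configs yielded here also carry matrix_instr_nonkdim (0 or 16) and kpack=2
    # in cfg.kwargs, and the `used` set keys on (Config, matrix_instr_nonkdim,
    # kpack). _filter_configs above pins num_stages to the backend default, so
    # its new_num_stages argument is currently unused.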
    def get_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        filtered_configs = self._filter_configs(
            self.mm_configs, self.default_num_stages
        )
        return partial(self.preprocess_mm_configs, configs=filtered_configs)

    def get_exhaustive_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        filtered_configs = self._filter_configs(
            self.exhaustive_configs, self.default_num_stages
        )
        return partial(self.preprocess_mm_configs, configs=filtered_configs)

    def get_extra_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        filtered_configs = self._filter_configs(
            self.extra_mm_configs, self.default_num_stages
        )
        return partial(self.preprocess_mm_configs, configs=filtered_configs)

    def get_int8_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        filtered_configs = self._filter_configs(
            self.int8_mm_configs, self.default_num_stages
        )
        return partial(self.preprocess_mm_configs, configs=filtered_configs)

    def get_mixed_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        mm_configs = (
            self.mm_configs + self.mixed_mm_configs
            if config.max_autotune_gemm_search_space == "EXHAUSTIVE"
            else self.mm_configs
        )
        filtered_configs = self._filter_configs(mm_configs, self.default_num_stages)
        return partial(self.preprocess_mm_configs, configs=filtered_configs)

    def get_persistent_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        filtered_configs = self._filter_configs(
            self.persistent_mm_configs, self.default_num_stages
        )
        return partial(self.preprocess_mm_configs, configs=filtered_configs)

    def get_scaled_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        filtered_configs = self._filter_configs(
            self.scaled_mm_configs, self.default_num_stages
        )
        return partial(self.preprocess_mm_configs, configs=filtered_configs)

    def get_scaled_persistent_mm_configs(
        self,
    ) -> partial[Generator[TritonConfig, None, None]]:
        filtered_configs = self._filter_configs(
            self.scaled_persistent_mm_configs, self.default_num_stages
        )
        return partial(self.preprocess_mm_configs, configs=filtered_configs)

    def get_mm_plus_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        filtered_configs = self._filter_configs(self.mm_plus_mm_configs, 1)
        return partial(self._finalize_mm_configs, configs=filtered_configs)

    def get_conv_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        filtered_configs = self._filter_configs(
            self.conv_configs, self.default_num_stages
        )
        return partial(self.preprocess_mm_configs, configs=filtered_configs)


class XPUConfigHeuristic(BaseConfigHeuristic):
    pass
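
One consequence of the per-backend overrides worth noting: the ROCm heuristic's exhaustive search space is much smaller than the base one. A quick sketch of the raw candidate counts implied by the list comprehensions above (illustration only; counts are before scaling and deduplication, and building ROCmConfigHeuristic assumes get_backend_num_stages() works on the current build):

    # Sketch only, not part of the diff.
    base = BaseConfigHeuristic()
    rocm = ROCmConfigHeuristic()
    print(len(base.exhaustive_configs))  # 5**3 * 5 * 3 = 1875 raw tuples
    print(len(rocm.exhaustive_configs))  # 5**3 * 2 * 2 = 500 (num_stages in {1, backend default})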