[Inductor] Reland Merge Triton ScaledMM as epilogue to MM template #150045 (#150441)

Merges https://github.com/pytorch/pytorch/pull/150438 and https://github.com/pytorch/pytorch/pull/150045. https://github.com/pytorch/pytorch/pull/150045 was already landed, but did not include a change that makes it unable to land internally.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150441
Approved by: https://github.com/clee2000
This commit is contained in:
PaulZhang12 2025-04-02 17:49:32 +00:00 committed by PyTorch MergeBot
parent 238109ad32
commit e62d958f02
4 changed files with 468 additions and 623 deletions

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import functools
import logging
from typing import Optional
from typing import Any, Optional
import torch
from torch._dynamo.utils import counters
@ -21,10 +21,16 @@ from ..codegen.cuda.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTem
from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate
from ..codegen.wrapper import PythonWrapperCodegen
from ..ir import FlexibleLayout, is_triton
from ..lowering import register_lowering
from ..lowering import (
add_layout_constraint,
constrain_to_fx_strides,
lowerings as L,
register_lowering,
)
from ..select_algorithm import (
autotune_select_algorithm,
ExternKernelChoice,
realize_inputs,
TritonTemplate,
)
from ..utils import (
@ -46,6 +52,8 @@ from .mm_common import (
mm_options,
persistent_mm_grid,
persistent_mm_options,
scale_mm_epilogue,
scaled_mm_options,
should_fallback_to_aten,
)
@ -121,7 +129,12 @@ mm_template = TritonTemplate(
idx_m = b_k_idx_vals
idx_n = offs_b_n[None, :]
{{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", indent_width=8)}}
acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
{% if USE_FAST_ACCUM %}
acc = tl.dot(a, b, acc, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE)
{% else %}
acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE)
{% endif %}
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
@ -192,7 +205,11 @@ mm_template = TritonTemplate(
idx_m = b_k_idx_vals
idx_n = offs_b_n[None, :]
{{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", indent_width=8)}}
acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
{% if USE_FAST_ACCUM %}
acc = tl.dot(a, b, acc, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE)
{% else %}
acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE)
{% endif %}
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
@ -295,20 +312,195 @@ persistent_tma_mm_template = TritonTemplate(
allow_tf32=ALLOW_TF32,
)
if ki == k_tiles - 1:
# rematerialize rm and rn to save registers
rcm = rm + tl.arange(0, BLOCK_M)
rcn = rn + tl.arange(0, BLOCK_N)
idx_m = rcm[:, None]
idx_n = rcn[None, :]
mask = (idx_m < M) & (idx_n < N)
{% if ki == k_tiles - 1 %}
# rematerialize rm and rn to save registers
rcm = rm + tl.arange(0, BLOCK_M)
rcn = rn + tl.arange(0, BLOCK_N)
idx_m = rcm[:, None]
idx_n = rcn[None, :]
mask = (idx_m < M) & (idx_n < N)
# inductor generates a suffix
{{store_output(("idx_m", "idx_n"), "acc", "mask", indent_width=12)}}
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
# inductor generates a suffix
{{store_output(("idx_m", "idx_n"), "acc", "mask", indent_width=12)}}
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
{% endif %}
""",
)
load_scales = r"""
@triton.jit
def load_scales(a_scale_ptr, b_scale_ptr, SCALING_ROWWISE: tl.constexpr):
if SCALING_ROWWISE:
# For row-wise scaling, we'll return the pointers
return a_scale_ptr, b_scale_ptr
else:
# For per-tensor scaling, we'll load the scalar values
a_scale = tl.load(a_scale_ptr)
b_scale = tl.load(b_scale_ptr)
return a_scale, b_scale
"""
apply_scaling = r"""
@triton.jit
def apply_scaling(
accumulator,
a_scale,
b_scale,
SCALING_ROWWISE: tl.constexpr,
offs_cm,
offs_cn,
M,
N,
stride_a_scale_m,
stride_b_scale_n,
):
if SCALING_ROWWISE:
# For row-wise scaling, we need to load the scales for each row/column
a_scales = tl.load(
a_scale + (offs_cm * stride_a_scale_m),
mask=offs_cm < M,
other=0.0,
)
b_scales = tl.load(
b_scale + (offs_cn * stride_b_scale_n),
mask=offs_cn < N,
other=0.0,
)
acc_scale = a_scales[:, None] * b_scales[None, :]
else:
# For per-tensor scaling, we can directly use the loaded scalar values
acc_scale = a_scale * b_scale
return accumulator * acc_scale
"""
device_tma = r"""
{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}}
M = {{size("A", 0)}}
N = {{size("B", 1)}}
K = {{size("A", 1)}}
if M * N == 0:
# early exit due to zero-size input(s)
return
stride_am = {{stride("A", 0)}}
stride_ak = {{stride("A", 1)}}
stride_bk = {{stride("B", 0)}}
stride_bn = {{stride("B", 1)}}
if SCALING_ROWWISE:
stride_a_scale_m = 1
stride_b_scale_n = 1
else:
stride_a_scale_m = 0
stride_b_scale_n = 0
start_pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_M)
num_pid_n = tl.cdiv(N, BLOCK_N)
k_tiles = tl.cdiv(K, BLOCK_K)
num_tiles = num_pid_m * num_pid_n
workspace_base = ws_ptr + start_pid * 2 * TMA_SIZE
a_desc_ptr = workspace_base
b_desc_ptr = workspace_base + TMA_SIZE
triton.language.extra.cuda.experimental_device_tensormap_create2d(
desc_ptr=a_desc_ptr,
global_address=A,
load_size=[BLOCK_M, BLOCK_K],
global_size=[M, K],
element_ty=A.dtype.element_ty,
)
triton.language.extra.cuda.experimental_device_tensormap_create2d(
desc_ptr=b_desc_ptr,
global_address=B,
load_size=[BLOCK_N, BLOCK_K],
global_size=[N, K],
element_ty=B.dtype.element_ty,
)
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
tiles_per_SM = num_tiles // NUM_SMS
if start_pid < num_tiles % NUM_SMS:
tiles_per_SM += 1
tile_id = start_pid - NUM_SMS
ki = -1
pid_m = 0
pid_n = 0
offs_am = 0
offs_bn = 0
num_pid_in_group = GROUP_M * num_pid_n
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
a_scale, b_scale = load_scales(A_inverse_scale, B_inverse_scale, SCALING_ROWWISE)
for _ in range(0, k_tiles * tiles_per_SM):
ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
if ki == 0:
tile_id += NUM_SMS
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * GROUP_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
pid_m = first_pid_m + (tile_id % group_size_m)
pid_n = (tile_id % num_pid_in_group) // group_size_m
offs_am = pid_m * BLOCK_M
offs_bn = pid_n * BLOCK_N
offs_k = ki * BLOCK_K
a = tl._experimental_descriptor_load(
a_desc_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], A.dtype.element_ty
)
b = tl._experimental_descriptor_load(
b_desc_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], B.dtype.element_ty
)
if USE_FAST_ACCUM:
accumulator = tl.dot(a, b.T, accumulator)
else:
accumulator += tl.dot(a, b.T)
{% if ki == k_tiles - 1 %}
# Apply inverse scaling
offs_cm = offs_am + tl.arange(0, BLOCK_M)
offs_cn = offs_bn + tl.arange(0, BLOCK_N)
# Apply scaling
accumulator = apply_scaling(
accumulator,
a_scale,
b_scale,
SCALING_ROWWISE,
offs_cm,
offs_cn,
M,
N,
stride_a_scale_m,
stride_b_scale_n,
)
idx_m = offs_cm[:, None]
idx_n = offs_cn[None, :]
mask = (idx_m < M) & (idx_n < N)
# inductor generates a suffix
{{store_output(("idx_m", "idx_n"), "accumulator", "mask", indent_width=12)}}
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
{% endif %}
"""
scaled_mm_device_tma_template = TritonTemplate(
name="scaled_mm_device_tma",
grid=persistent_mm_grid,
source=device_tma + load_scales + apply_scaling,
)
# prevent duplication registration of extern functions
@functools.lru_cache(None)
@ -330,6 +522,10 @@ aten__sparse_semi_structured_mm = ExternKernelChoice(
has_out_variant=False,
)
aten__fp8_mm = ExternKernelChoice(
torch._scaled_mm, "at::_scaled_mm_out", op_overload=aten._scaled_mm.out
)
def _is_int8_mat(mat):
return mat.get_dtype() in (torch.int8, torch.uint8)
@ -340,6 +536,16 @@ def _is_large_block_for_cpu(m, n, k):
return m * n > 2**13
@functools.lru_cache
def using_b200() -> bool:
"""Returns true if the device is a NVIDIA B200, otherwise returns false."""
if not torch.cuda.is_available():
return False
# compute capability 10.0 or 10.0a is NVIDIA B200
device_properties = torch.cuda.get_device_properties(torch.cuda.current_device())
return device_properties.major == 10
def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1):
"""
Giving torch.addmm a 1D tensor calls a different (faster) cublasLt
@ -351,6 +557,32 @@ def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1):
return torch.addmm(inp, mat1, mat2, out=out, alpha=alpha, beta=beta)
def check_supported_striding(mat_a, mat_b) -> None:
def is_row_major(stride) -> bool:
return V.graph.sizevars.statically_known_equals(stride[1], 1)
def is_col_major(stride) -> bool:
return V.graph.sizevars.statically_known_equals(stride[0], 1)
def has_zero_dim(size) -> bool:
return bool(
V.graph.sizevars.statically_known_equals(size[0], 0)
or V.graph.sizevars.statically_known_equals(size[1], 0)
)
# Check mat_a (self) stride requirements
torch._check(
is_row_major(mat_a.get_stride()) or has_zero_dim(mat_a.get_size()),
lambda: f"mat_a must be row_major, got stride {mat_a.get_stride()}",
)
# Check mat_b stride requirements
torch._check(
is_col_major(mat_b.get_stride()) or has_zero_dim(mat_b.get_size()),
lambda: f"mat_b must be col_major, got stride {mat_b.get_stride()}",
)
aten_bias_addmm = ExternKernelChoice(bias_addmm, None)
@ -750,6 +982,151 @@ def tuned_sparse_semi_structured_mm(
)
add_layout_constraint(aten._scaled_mm.default, constrain_to_fx_strides)
@register_lowering(aten._scaled_mm.default, type_promotion_kind=None) # type: ignore[misc]
def tuned_scaled_mm(
mat_a,
mat_b,
scale_a,
scale_b,
bias=None,
scale_result=None,
out_dtype=None,
use_fast_accum=False,
layout=None,
):
m, n, k, layout, mat_a, mat_b = mm_args(
mat_a, mat_b, layout=layout, out_dtype=out_dtype
)
# below is for getting an overview logging info of inductor mms
counters["aten_mm_info"][f"aten._scaled_mm.default_{m}_{n}_{k}"] += 1
log.info(
"Tuned aten._scaled_mm.default: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s",
m,
n,
k,
mat_a.get_dtype(),
mat_b.get_dtype(),
layout,
)
device_type = ir.get_device_type(mat_a)
check_supported_striding(mat_a, mat_b)
scale_a_real, scale_b_real = realize_inputs(scale_a, scale_b)
input_nodes: tuple[Any, ...]
if not bias:
input_nodes = (mat_a, mat_b, scale_a_real, scale_b_real)
else:
bias_real = realize_inputs(bias)
input_nodes = (mat_a, mat_b, scale_a_real, scale_b_real, bias_real)
aten_choice = aten__fp8_mm.bind(
input_nodes, layout, out_dtype=out_dtype, use_fast_accum=use_fast_accum
)
choices = []
if use_aten_gemm_kernels():
choices.append(aten_choice)
_, is_nonzero = _is_static_problem(layout)
scaled_mm_configs = V.choices.get_scaled_mm_configs(device_type)
scaled_persistent_mm_configs = V.choices.get_scaled_persistent_mm_configs(
device_type
)
if is_nonzero and use_triton_template(layout, enable_float8=True):
triton_input_nodes: tuple[Any, ...]
if bias and len(mat_b.get_size()) == len(bias.get_size()) + 1:
# Need to unsqueeze bias from [N] -> [1, N]
triton_bias = L[aten.unsqueeze](bias, 0)
else:
triton_bias = bias
if len(scale_a.get_size()) == 0 or len(scale_b.get_size()) == 0:
assert len(scale_a.get_size()) == len(scale_b.get_size())
# Need to unsqueeze scale from [] -> [1, 1]
triton_scale_a = L[aten.unsqueeze](L[aten.unsqueeze](scale_a, 0), 1)
triton_scale_b = L[aten.unsqueeze](L[aten.unsqueeze](scale_b, 0), 1)
else:
triton_scale_a = scale_a
triton_scale_b = scale_b
if bias:
triton_input_nodes = (
mat_a,
mat_b,
triton_scale_a,
triton_scale_b,
triton_bias,
)
suffix_args = 3
else:
triton_input_nodes = (mat_a, mat_b, triton_scale_a, triton_scale_b)
suffix_args = 2
# TODO (paulzhan): There is no template that exists for bias and TMA
# Don't run tma template currently if bias exists
if use_triton_tma_template(mat_a, mat_b) and not bias:
for config in scaled_persistent_mm_configs(m, n, k):
kwargs = scaled_mm_options(
config,
m,
n,
k,
layout,
scale_a,
scale_b,
use_fast_accum,
device_tma=True,
)
scaled_mm_device_tma_template.maybe_append_choice(
choices,
input_nodes=triton_input_nodes,
layout=layout,
workspace_arg=get_tma_workspace_arg(
num_tma_descriptors=2,
device=mat_a.get_device(),
),
**kwargs,
)
for config in scaled_mm_configs(m, n, k):
if k == 16 and config.kwargs["BLOCK_M"] >= 64:
continue # Triton crashes in this case
# On NVIDIA B200 GPUs, K dim must be >= 32 for tcgen05.mma.kind::f8f6f4.* PTX instruction to be valid
# source: https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-shape
if using_b200() and k < 32:
continue
kwargs = scaled_mm_options(
config, m, n, k, layout, scale_a, scale_b, use_fast_accum
)
# possibly appends a TritonTemplateCaller to choices
mm_template.maybe_append_choice(
choices,
input_nodes=triton_input_nodes,
layout=layout,
**kwargs,
suffix_args=suffix_args,
epilogue_fn=scale_mm_epilogue(),
)
if is_nonzero and use_ck_gemm_template(layout, m, n, k):
CKGemmTemplate.add_ck_gemm_choices(choices, layout, input_nodes)
if should_fallback_to_aten(choices):
return aten_choice.output_node()
return autotune_select_algorithm("scaled_mm", choices, input_nodes, layout)
@functools.lru_cache(None)
def _is_sm7x_or_older_gpu(index: Optional[int]) -> bool:
props = torch.cuda.get_device_properties(index or 0)

View File

@ -76,6 +76,7 @@ def mm_options(config, sym_m, sym_n, sym_k, layout):
GROUP_M=8,
EVEN_K=even_k_symbolic,
ALLOW_TF32=allow_tf32,
USE_FAST_ACCUM=False, # Option for _scaled_mm
ACC_TYPE=acc_type(layout.dtype),
num_stages=config.num_stages,
num_warps=config.num_warps,
@ -92,6 +93,47 @@ def persistent_mm_options(mat1, mat2):
)
def scaled_mm_options( # type: ignore[no-untyped-def]
config, # triton.Config
sym_m: sympy.core.numbers.Integer,
sym_n: sympy.core.numbers.Integer,
sym_k: sympy.core.numbers.Integer,
layout: Layout,
scale_a,
scale_b,
use_fast_accum: bool,
device_tma: bool = False,
) -> dict[str, Any]:
def are_compatible_scales(size_a, size_b) -> bool:
# Same sized scales are compatable
if len(size_a) == len(size_b):
return True
# Both need to be scalars or len(1) tensors
if len(size_a) <= 1 and len(size_b) <= 1:
return True
return False
size_a, size_b = scale_a.get_size(), scale_b.get_size()
assert are_compatible_scales(size_a, size_b), (
"Expect scale_a and scale_b to be either both scalars (including single-element tensors) "
f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}."
)
mm_template_options = mm_options(config, sym_m, sym_n, sym_k, layout)
mm_template_options["ACC_TYPE"] = "tl.float32"
mm_template_options["USE_FAST_ACCUM"] = use_fast_accum
mm_template_options["SCALING_ROWWISE"] = len(size_a) == 2
if device_tma:
mm_template_options["TMA_SIZE"] = TMA_DESCRIPTOR_SIZE
mm_template_options["NUM_SMS"] = get_num_sms()
return mm_template_options
def mm_args(
mat1,
mat2,
@ -154,6 +196,34 @@ def addmm_epilogue(dtype, alpha, beta):
return epilogue
def scale_mm_epilogue():
"""
Create an epilogue function that applies scaling to matrix multiplication result
using the given scale factors.
Args:
dtype: The data type of the output
scale_a: Scale factor for matrix A
scale_b: Scale factor for matrix B
Returns:
Epilogue function that takes the accumulator and applies scaling
"""
def epilogue(acc, inv_a_scale, inv_b_scale, bias=None):
# The epilogue function receives the accumulator (result of mat1 @ mat2)
# and applies the scaling factors
# In the original scaled_mm, we use inverse scales, so we multiply by them
mul_scales = V.ops.mul(inv_a_scale, inv_b_scale)
mul_acc = V.ops.mul(acc, mul_scales)
if bias is not None:
return V.ops.add(mul_acc, bias)
else:
return mul_acc
return epilogue
def _is_static_problem(layout: Layout) -> tuple[bool, bool]:
"""
Check if input tensors and output layout have static shapes and non-zero sizes.

View File

@ -1,608 +0,0 @@
import functools
import logging
from collections.abc import Sequence
from typing import Any, Optional
import sympy
import torch
from torch._dynamo.utils import counters
from torch._inductor.codegen.rocm.ck_universal_gemm_template import CKGemmTemplate
from torch.utils._triton import has_triton_tma_device
from ..config import triton as triton_config
from ..ir import _IntLike, ChoiceCaller, get_device_type, Layout, StorageBox, TensorBox
from ..lowering import add_layout_constraint, constrain_to_fx_strides, register_lowering
from ..select_algorithm import (
autotune_select_algorithm,
ExternKernelChoice,
realize_inputs,
TritonTemplate,
)
from ..utils import (
get_num_sms,
get_tma_workspace_arg,
TMA_DESCRIPTOR_SIZE,
use_aten_gemm_kernels,
use_ck_gemm_template,
use_triton_template,
)
from ..virtualized import V
from .mm_common import (
_is_static_problem,
mm_args,
mm_grid,
persistent_mm_grid,
should_fallback_to_aten,
)
log = logging.getLogger(__name__)
aten = torch.ops.aten
load_scales = r"""
@triton.jit
def load_scales(a_scale_ptr, b_scale_ptr, SCALING_ROWWISE: tl.constexpr):
if SCALING_ROWWISE:
# For row-wise scaling, we'll return the pointers
return a_scale_ptr, b_scale_ptr
else:
# For per-tensor scaling, we'll load the scalar values
a_scale = tl.load(a_scale_ptr)
b_scale = tl.load(b_scale_ptr)
return a_scale, b_scale
"""
apply_scaling = r"""
@triton.jit
def apply_scaling(
accumulator,
a_scale,
b_scale,
SCALING_ROWWISE: tl.constexpr,
offs_cm,
offs_cn,
M,
N,
stride_a_scale_m,
stride_b_scale_n,
):
if SCALING_ROWWISE:
# For row-wise scaling, we need to load the scales for each row/column
a_scales = tl.load(
a_scale + (offs_cm * stride_a_scale_m),
mask=offs_cm < M,
other=0.0,
)
b_scales = tl.load(
b_scale + (offs_cn * stride_b_scale_n),
mask=offs_cn < N,
other=0.0,
)
acc_scale = a_scales[:, None] * b_scales[None, :]
else:
# For per-tensor scaling, we can directly use the loaded scalar values
acc_scale = a_scale * b_scale
return accumulator * acc_scale
"""
device_tma = r"""
{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}}
M = {{size("A", 0)}}
N = {{size("B", 1)}}
K = {{size("A", 1)}}
if M * N == 0:
# early exit due to zero-size input(s)
return
stride_am = {{stride("A", 0)}}
stride_ak = {{stride("A", 1)}}
stride_bk = {{stride("B", 0)}}
stride_bn = {{stride("B", 1)}}
if SCALING_ROWWISE:
stride_a_scale_m = 1
stride_b_scale_n = 1
else:
stride_a_scale_m = 0
stride_b_scale_n = 0
start_pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_M)
num_pid_n = tl.cdiv(N, BLOCK_N)
k_tiles = tl.cdiv(K, BLOCK_K)
num_tiles = num_pid_m * num_pid_n
workspace_base = ws_ptr + start_pid * 2 * TMA_SIZE
a_desc_ptr = workspace_base
b_desc_ptr = workspace_base + TMA_SIZE
triton.language.extra.cuda.experimental_device_tensormap_create2d(
desc_ptr=a_desc_ptr,
global_address=A,
load_size=[BLOCK_M, BLOCK_K],
global_size=[M, K],
element_ty=A.dtype.element_ty,
)
triton.language.extra.cuda.experimental_device_tensormap_create2d(
desc_ptr=b_desc_ptr,
global_address=B,
load_size=[BLOCK_N, BLOCK_K],
global_size=[N, K],
element_ty=B.dtype.element_ty,
)
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
tiles_per_SM = num_tiles // NUM_SMS
if start_pid < num_tiles % NUM_SMS:
tiles_per_SM += 1
tile_id = start_pid - NUM_SMS
ki = -1
pid_m = 0
pid_n = 0
offs_am = 0
offs_bn = 0
num_pid_in_group = GROUP_M * num_pid_n
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
a_scale, b_scale = load_scales(A_inverse_scale, B_inverse_scale, SCALING_ROWWISE)
for _ in range(0, k_tiles * tiles_per_SM):
ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
if ki == 0:
tile_id += NUM_SMS
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * GROUP_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
pid_m = first_pid_m + (tile_id % group_size_m)
pid_n = (tile_id % num_pid_in_group) // group_size_m
offs_am = pid_m * BLOCK_M
offs_bn = pid_n * BLOCK_N
offs_k = ki * BLOCK_K
a = tl._experimental_descriptor_load(
a_desc_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], A.dtype.element_ty
)
b = tl._experimental_descriptor_load(
b_desc_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], B.dtype.element_ty
)
if USE_FAST_ACCUM:
accumulator = tl.dot(a, b.T, accumulator)
else:
accumulator += tl.dot(a, b.T)
if ki == k_tiles - 1:
# Apply inverse scaling
offs_cm = offs_am + tl.arange(0, BLOCK_M)
offs_cn = offs_bn + tl.arange(0, BLOCK_N)
# Apply scaling
accumulator = apply_scaling(
accumulator,
a_scale,
b_scale,
SCALING_ROWWISE,
offs_cm,
offs_cn,
M,
N,
stride_a_scale_m,
stride_b_scale_n,
)
idx_m = offs_cm[:, None]
idx_n = offs_cn[None, :]
mask = (idx_m < M) & (idx_n < N)
# inductor generates a suffix
{{store_output(("idx_m", "idx_n"), "accumulator", "mask", indent_width=12)}}
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
"""
scaled_mm_device_tma_template = TritonTemplate(
name="scaled_mm_device_tma",
grid=persistent_mm_grid,
source=device_tma + load_scales + apply_scaling,
)
scaled_mm_template = TritonTemplate(
name="scaled_mm",
grid=mm_grid,
source=r"""
{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}}
M = {{size("A", 0)}}
N = {{size("B", 1)}}
K = {{size("A", 1)}}
if M * N == 0:
# early exit due to zero-size input(s)
return
stride_am = {{stride("A", 0)}}
stride_ak = {{stride("A", 1)}}
stride_bk = {{stride("B", 0)}}
stride_bn = {{stride("B", 1)}}
# based on triton.ops.matmul
pid = tl.program_id(0)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = tl.arange(0, BLOCK_K)
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
for k in range(K, 0, -BLOCK_K):
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
a = tl.load(A, mask=rk[None, :] < k, other=0.)
b = tl.load(B, mask=rk[:, None] < k, other=0.)
if USE_FAST_ACCUM:
acc = tl.dot(a, b, acc, out_dtype=ACC_TYPE)
else:
acc += tl.dot(a, b, out_dtype=ACC_TYPE)
A += BLOCK_K * stride_ak
B += BLOCK_K * stride_bk
if SCALING_ROWWISE:
inv_a_scale_row = tl.load(A_inverse_scale + rm, mask=rm < M)
inv_b_scale_row = tl.load(B_inverse_scale + rn, mask=rn < N)
inv_scale_row = inv_a_scale_row[:, None] * inv_b_scale_row[None, :]
acc *= inv_scale_row
else:
# for tensor-wise scaling, the scales are scalars
inv_a_scale = tl.load(A_inverse_scale)
inv_b_scale = tl.load(B_inverse_scale)
inv_scale = inv_a_scale * inv_b_scale
acc *= inv_scale
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
idx_m = rm[:, None]
idx_n = rn[None, :]
mask = (idx_m < M) & (idx_n < N)
# inductor generates a suffix
{{store_output(("idx_m", "idx_n"), "acc", "mask")}}
""",
)
# Inductor does not allow optional tensor input arguments currently (pass None as an
# input node to template choices), but since for _scaled_mm there is only one such arg
# (bias), work around by having a second template when bias is provided.
scaled_mm_bias_template = TritonTemplate(
name="scaled_mm_bias",
grid=mm_grid,
source=r"""
{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale", "bias_ptr")}}
M = {{size("A", 0)}}
N = {{size("B", 1)}}
K = {{size("A", 1)}}
if M * N == 0:
# early exit due to zero-size input(s)
return
stride_am = {{stride("A", 0)}}
stride_ak = {{stride("A", 1)}}
stride_bk = {{stride("B", 0)}}
stride_bn = {{stride("B", 1)}}
# based on triton.ops.matmul
pid = tl.program_id(0)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = tl.arange(0, BLOCK_K)
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
for k in range(K, 0, -BLOCK_K):
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
a = tl.load(A, mask=rk[None, :] < k, other=0.)
b = tl.load(B, mask=rk[:, None] < k, other=0.)
if USE_FAST_ACCUM:
acc = tl.dot(a, b, acc, out_dtype=ACC_TYPE)
else:
acc += tl.dot(a, b, out_dtype=ACC_TYPE)
A += BLOCK_K * stride_ak
B += BLOCK_K * stride_bk
if SCALING_ROWWISE:
inv_a_scale_row = tl.load(A_inverse_scale + rm, mask=rm < M)
inv_b_scale_row = tl.load(B_inverse_scale + rn, mask=rn < N)
inv_scale_row = inv_a_scale_row[:, None] * inv_b_scale_row[None, :]
acc *= inv_scale_row
else:
# for tensor-wise scaling, the scales are scalars
inv_a_scale = tl.load(A_inverse_scale)
inv_b_scale = tl.load(B_inverse_scale)
inv_scale = inv_a_scale * inv_b_scale
acc *= inv_scale
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
# bias
bias = tl.load(bias_ptr + rn, mask=rn < N)
acc += bias
idx_m = rm[:, None]
idx_n = rn[None, :]
mask = (idx_m < M) & (idx_n < N)
# inductor generates a suffix
{{store_output(("idx_m", "idx_n"), "acc", "mask")}}
""",
)
aten__fp8_mm = ExternKernelChoice(
torch._scaled_mm, "at::_scaled_mm_out", op_overload=aten._scaled_mm.out
)
def are_compatible_scales(size_a: Sequence[int], size_b: Sequence[int]) -> bool:
# Same sized scales are compatable
if len(size_a) == len(size_b):
return True
# Both need to be scalars or len(1) tensors
if len(size_a) <= 1 and len(size_b) <= 1:
return True
return False
def check_supported_striding(mat_a: TensorBox, mat_b: TensorBox) -> None:
def is_row_major(stride: Sequence[_IntLike]) -> bool:
return stride[1] == 1
def is_col_major(stride: Sequence[_IntLike]) -> bool:
return stride[0] == 1
def has_zero_dim(size: Sequence[_IntLike]) -> bool:
return bool(size[0] == 0 or size[1] == 0)
# Check mat_a (self) stride requirements
torch._check(
is_row_major(mat_a.get_stride()) or has_zero_dim(mat_a.get_size()),
lambda: f"mat_a must be row_major, got stride {mat_a.get_stride()}",
)
# Check mat_b stride requirements
torch._check(
is_col_major(mat_b.get_stride()) or has_zero_dim(mat_b.get_size()),
lambda: f"mat_b must be col_major, got stride {mat_b.get_stride()}",
)
def scaled_mm_options_device_tma( # type: ignore[no-untyped-def]
config, # triton.Config
sym_m: sympy.core.numbers.Integer,
sym_n: sympy.core.numbers.Integer,
sym_k: sympy.core.numbers.Integer,
layout: Layout,
scale_a: StorageBox,
scale_b: StorageBox,
use_fast_accum: bool,
) -> dict[str, Any]:
even_k_symbolic = (
sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"]
)
size_a, size_b = scale_a.get_size(), scale_b.get_size()
assert are_compatible_scales(size_a, size_b), (
"Expect scale_a and scale_b to be either both scalars (including single-element tensors) "
f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}."
)
return dict(
GROUP_M=8,
EVEN_K=even_k_symbolic,
ACC_TYPE="tl.float32",
USE_FAST_ACCUM=use_fast_accum,
num_stages=config.num_stages,
num_warps=config.num_warps,
# tensor-wise scaling if scalar scales
SCALING_ROWWISE=len(scale_a.get_size()) == 2,
TMA_SIZE=TMA_DESCRIPTOR_SIZE,
NUM_SMS=get_num_sms(),
**config.kwargs,
)
def scaled_mm_options( # type: ignore[no-untyped-def]
config, # triton.Config
sym_m: sympy.core.numbers.Integer,
sym_n: sympy.core.numbers.Integer,
sym_k: sympy.core.numbers.Integer,
layout: Layout,
scale_a: StorageBox,
scale_b: StorageBox,
use_fast_accum: bool,
) -> dict[str, Any]:
even_k_symbolic = (
sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"]
)
size_a, size_b = scale_a.get_size(), scale_b.get_size()
assert are_compatible_scales(size_a, size_b), (
"Expect scale_a and scale_b to be either both scalars (including single-element tensors) "
f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}."
)
return dict(
GROUP_M=8,
EVEN_K=even_k_symbolic,
ACC_TYPE="tl.float32",
USE_FAST_ACCUM=use_fast_accum,
num_stages=config.num_stages,
num_warps=config.num_warps,
# tensor-wise scaling if scalar scales
SCALING_ROWWISE=len(scale_a.get_size()) == 2,
**config.kwargs,
)
add_layout_constraint(aten._scaled_mm.default, constrain_to_fx_strides)
def use_persistent_tma(k: sympy.core.numbers.Integer, has_bias: bool) -> bool:
available = has_triton_tma_device() and triton_config.enable_persistent_tma_matmul
# _determine_swizzle_mode_2d requires BLOCK_K to be at least 32 contiguous bytes
# When K is 16, BLOCK_K = 16 and is not valid
min_k = k >= 32
return available and min_k and not has_bias
@register_lowering(aten._scaled_mm.default, type_promotion_kind=None) # type: ignore[misc]
def tuned_scaled_mm(
mat_a: TensorBox,
mat_b: TensorBox,
scale_a: TensorBox,
scale_b: TensorBox,
bias: Optional[TensorBox] = None,
scale_result: Optional[TensorBox] = None,
out_dtype: Optional[torch.dtype] = None,
use_fast_accum: bool = False,
layout: Optional[Layout] = None,
) -> TensorBox:
m, n, k, layout, mat_a, mat_b = mm_args(
mat_a, mat_b, layout=layout, out_dtype=out_dtype
)
# below is for getting an overview logging info of inductor mms
counters["aten_mm_info"][f"aten._scaled_mm.default_{m}_{n}_{k}"] += 1
log.info(
"Tuned aten._scaled_mm.default: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s",
m,
n,
k,
mat_a.get_dtype(),
mat_b.get_dtype(),
layout,
)
device_type = get_device_type(mat_a)
check_supported_striding(mat_a, mat_b)
scale_a, scale_b = realize_inputs(scale_a, scale_b)
input_nodes: tuple[Any, ...]
# workaround for Inductor not supporting optional tensor input arguments
if bias is None:
input_nodes = (mat_a, mat_b, scale_a, scale_b)
triton_template = scaled_mm_template
else:
bias = realize_inputs(bias)
input_nodes = (mat_a, mat_b, scale_a, scale_b, bias)
triton_template = scaled_mm_bias_template
aten_choice = aten__fp8_mm.bind(
input_nodes, layout, out_dtype=out_dtype, use_fast_accum=use_fast_accum
)
choices: list[ChoiceCaller] = []
if use_aten_gemm_kernels():
choices.append(aten_choice)
_, is_nonzero = _is_static_problem(layout)
scaled_mm_configs = V.choices.get_scaled_mm_configs(device_type)
scaled_persistent_mm_configs = V.choices.get_scaled_persistent_mm_configs(
device_type
)
if is_nonzero and use_triton_template(layout, enable_float8=True):
if use_persistent_tma(k, bias is not None):
for config in scaled_persistent_mm_configs(m, n, k):
kwargs = scaled_mm_options_device_tma(
config, m, n, k, layout, scale_a, scale_b, use_fast_accum
)
input_nodes = (mat_a, mat_b, scale_a, scale_b)
scaled_mm_device_tma_template.maybe_append_choice(
choices,
input_nodes=input_nodes,
layout=layout,
workspace_arg=get_tma_workspace_arg(
num_tma_descriptors=2,
device=mat_a.get_device(),
),
**kwargs,
)
else:
for config in scaled_mm_configs(m, n, k):
if k == 16 and config.kwargs["BLOCK_M"] >= 64:
continue # Triton crashes in this case
# On NVIDIA B200 GPUs, K dim must be >= 32 for tcgen05.mma.kind::f8f6f4.* PTX instruction to be valid
# source: https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-shape
if using_b200() and k < 32:
continue
kwargs = scaled_mm_options(
config, m, n, k, layout, scale_a, scale_b, use_fast_accum
)
# possibly appends a TritonTemplateCaller to choices
triton_template.maybe_append_choice(
choices,
input_nodes=input_nodes,
layout=layout,
**kwargs,
)
if is_nonzero and use_ck_gemm_template(layout, m, n, k):
CKGemmTemplate.add_ck_gemm_choices(choices, layout, input_nodes)
if should_fallback_to_aten(choices):
return aten_choice.output_node()
return autotune_select_algorithm("scaled_mm", choices, input_nodes, layout)
@functools.lru_cache
def using_b200() -> bool:
"""Returns true if the device is a NVIDIA B200, otherwise returns false."""
if not torch.cuda.is_available():
return False
# compute capability 10.0 or 10.0a is NVIDIA B200
device_properties = torch.cuda.get_device_properties(torch.cuda.current_device())
return device_properties.major == 10

View File

@ -1377,7 +1377,7 @@ def use_triton_tma_template(*matrices: IRNode) -> bool:
return False
dtype = x.get_dtype()
if dtype not in (torch.float16, torch.bfloat16):
if dtype not in (torch.float16, torch.bfloat16, torch.float8_e4m3fn):
return False
layout = x.get_layout()
@ -1388,6 +1388,12 @@ def use_triton_tma_template(*matrices: IRNode) -> bool:
inner_dim = layout.size[1]
if transposed:
inner_dim = layout.size[0]
if dtype == torch.float8_e4m3fn and V.graph.sizevars.statically_known_lt(
inner_dim, 32
):
return False
inner_bytes = inner_dim * dtype.itemsize
return V.graph.sizevars.statically_known_multiple_of(inner_bytes, TMA_ALIGNMENT)