[cutlass backend] Add (limited) bmm dynamic shape support (#152393)

Differential Revision: D73626732

In this PR, we add dynamic shape support for bmm, provided that the batch stride is the largest stride for each of A, B, and D. For example, for A of size `(B, M, K)`, we support strides `(M*K, K, 1)` and `(M*K, 1, M)`. Under this assumption, we can infer the batch stride from the existing runtime arguments.
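
As an illustration (a minimal standalone sketch, not the Inductor code itself): when the batch stride is the largest stride of a 3D tensor, it equals the product of the two non-batch sizes, so it can be computed from the sizes already passed at runtime instead of being passed separately.

```python
import torch

def inferred_batch_stride(t: torch.Tensor) -> int:
    # Assumes t is 3D and its batch (dim 0) stride is the largest stride,
    # e.g. (M*K, K, 1) or (M*K, 1, M); the batch stride is then simply M*K.
    assert t.dim() == 3 and t.stride(0) == max(t.stride())
    return t.size(1) * t.size(2)

a = torch.empty(4, 8, 16)                   # (B, M, K), stride (M*K, K, 1)
b = torch.empty(4, 16, 8).transpose(1, 2)   # (B, M, K), stride (M*K, 1, M)
assert inferred_batch_stride(a) == a.stride(0)
assert inferred_batch_stride(b) == b.stride(0)
```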

The reason for this restriction is that we don't want to add 2-3 more runtime params. The concerns are added complexity and a possible perf regression, though we didn't verify the latter.

We can revisit this if the need arises.

We also remove the `B = 1` special case for normal mm and addmm. We tested this and didn't see a perf regression, but we are open to revisiting it as well.
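
For reference, a minimal usage sketch of what this enables (assuming an SM90+ GPU and the same Inductor config knobs the test below uses; exact flag names may vary by version):

```python
import torch
from torch._inductor import config

a = torch.randn(10, 4096, 2048, device="cuda", dtype=torch.float16)  # (B, M, K), stride (M*K, K, 1)
b = torch.randn(10, 2048, 1024, device="cuda", dtype=torch.float16)  # (B, K, N), stride (K*N, N, 1)

compiled_bmm = torch.compile(torch.bmm, dynamic=True)
with config.patch({"max_autotune": True, "max_autotune_gemm_backends": "CUTLASS"}):
    out = compiled_bmm(a, b)  # batch stride is inferred; no extra runtime params
```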

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152393
Approved by: https://github.com/ColinPeppler
Henry Tsang 2025-04-30 04:36:24 +00:00 committed by PyTorch MergeBot
parent e5ea7911ea
commit ee2d104c05
5 changed files with 92 additions and 27 deletions

View File

@@ -455,7 +455,7 @@ class TestCutlassBackend(TestCase):
torch.testing.assert_close(actual, expected)
@unittest.skipIf(not SM90OrLater, "need sm_90")
@parametrize("dynamic", (False,))
@parametrize("dynamic", (False, True))
@parametrize("use_aoti", (False, True))
@parametrize("dtype", (torch.float16, torch.bfloat16))
@mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
@@ -478,15 +478,25 @@ class TestCutlassBackend(TestCase):
# B, M, N, K
shapes = [
(10, 4096, 2048, 25728),
(20, 2048, 1024, 12864),
]
shapes = shapes[0:1] if not dynamic else shapes
inputs = [
(
torch.randn(B, M, K).cuda().to(dtype),
torch.randn(B, K, N).cuda().to(dtype),
torch.randn(B, N, K).cuda().to(dtype).permute(0, 2, 1),
)
for B, M, N, K in shapes
]
dynamic_shapes = (
{
"a": {0: Dim.DYNAMIC, 1: Dim.DYNAMIC, 2: Dim.DYNAMIC},
"b": {0: Dim.DYNAMIC, 1: Dim.DYNAMIC, 2: Dim.DYNAMIC},
}
if dynamic
else None
)
with config.patch(
{
"max_autotune": True,
@@ -497,7 +507,9 @@ class TestCutlassBackend(TestCase):
):
expected = [model(*input) for input in inputs]
if use_aoti:
actual = AOTIRunnerUtil.run_multiple(model, inputs, dynamic_shapes=None)
actual = AOTIRunnerUtil.run_multiple(
model, inputs, dynamic_shapes=dynamic_shapes
)
else:
compiled_model = torch.compile(model, dynamic=dynamic)
actual = [compiled_model(*input) for input in inputs]

View File

@@ -1,5 +1,7 @@
# mypy: allow-untyped-defs
import itertools
import logging
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, Literal, Optional, TYPE_CHECKING, Union
@@ -46,7 +48,7 @@ def _normalize_idx(index: int, total_length: int) -> int:
return index if index >= 0 else index + total_length
ValidLayoutSymbols = Literal["M", "N", "K", "lda", "ldb", "ldc", "ldd"]
ValidLayoutSymbols = Literal["M", "N", "K", "B", "lda", "ldb", "ldc", "ldd"]
ValidLayoutAttrs = Literal["size", "stride"]
@@ -70,7 +72,7 @@ class CUDAKernel(Kernel):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.layout_args: dict[str, LayoutArg] = {}
self.layout_args: dict[str, list[LayoutArg]] = defaultdict(list)
# Mapping from arg name to IRNode.
self.named_nodes: dict[str, IRNode] = {}
@@ -84,7 +86,9 @@ class CUDAKernel(Kernel):
self, node: IRNode, attr: ValidLayoutAttrs, dim: int
) -> Optional[LayoutArg]:
matches = [
arg for arg in self.layout_args.values() if arg.matches(node, attr, dim)
arg
for arg in itertools.chain.from_iterable(self.layout_args.values())
if arg.matches(node, attr, dim)
]
if len(matches) >= 1:
# Verify all matches have the same node, attribute, and dimension
@@ -105,19 +109,27 @@ class CUDAKernel(Kernel):
self, symbol: ValidLayoutSymbols, node: IRNode, attr: ValidLayoutAttrs, dim: int
):
arg = LayoutArg(node, symbol, attr, dim)
self.layout_args.setdefault(symbol, arg)
self.layout_args[symbol].append(arg)
def init_layout_args(self) -> None:
X = self.named_nodes["X"]
W = self.named_nodes["W"]
Y = self.named_nodes["Y"]
Bias = self.named_nodes.get("Bias", None)
mdim = _normalize_idx(-2, len(X.get_size()))
ndim = _normalize_idx(-1, len(W.get_size()))
kdim = _normalize_idx(-1, len(X.get_size()))
self.add_layout_arg("M", X, "size", mdim)
self.add_layout_arg("N", W, "size", ndim)
self.add_layout_arg("K", X, "size", kdim)
x_mdim = _normalize_idx(-2, len(X.get_size()))
x_kdim = _normalize_idx(-1, len(X.get_size()))
w_kdim = _normalize_idx(-2, len(W.get_size()))
w_ndim = _normalize_idx(-1, len(W.get_size()))
y_mdim = _normalize_idx(-2, len(Y.get_size()))
y_ndim = _normalize_idx(-1, len(Y.get_size()))
self.add_layout_arg("M", X, "size", x_mdim)
self.add_layout_arg("K", X, "size", x_kdim)
self.add_layout_arg("K", W, "size", w_kdim)
self.add_layout_arg("N", W, "size", w_ndim)
self.add_layout_arg("M", Y, "size", y_mdim)
self.add_layout_arg("N", Y, "size", y_ndim)
if len(X.get_size()) > 2:
self.add_layout_arg("B", X, "size", 0)
lda_dim = self.find_ld_idx(X)
ldb_dim = self.find_ld_idx(W)
@@ -145,11 +157,12 @@ class CUDAKernel(Kernel):
M = X.get_size()[mdim]
N = W.get_size()[ndim]
K = X.get_size()[kdim]
B = X.get_size()[0] if len(X.get_size()) > 2 else 1
LDA = get_ld(X)
LDB = get_ld(W)
LDC = get_ld(Bias) if Bias else 0
LDD = get_ld(Y)
return (M, N, K, LDA, LDB, LDC, LDD)
return (M, N, K, B, LDA, LDB, LDC, LDD)
@staticmethod
def find_ld_idx(node: IRNode) -> int:
@@ -264,7 +277,7 @@ class CUDATemplateKernel(CUDAKernel):
self.init_layout_args()
size_args = [
f"const int {s}" for s in ("M", "N", "K", "lda", "ldb", "ldc", "ldd")
f"const int {s}" for s in ("M", "N", "K", "B", "lda", "ldb", "ldc", "ldd")
]
runtime_arg_decls = ",".join(
@@ -461,6 +474,29 @@ class CUDATemplateKernel(CUDAKernel):
return str(stride)
return self.find_symbol(node, "stride", dim=index) or str(stride)
def batch_stride(self, node: IRNode, default_value: int = 0) -> str:
"""
Hook called from template code to get the batch stride of an arg.
Returns 0 if batch dim is not present.
This method assumes that batch stride is the largest stride.
"""
if node is None:
return str(default_value)
if len(node.get_size()) < 3:
return str(default_value)
batch_stride = node.get_stride()[0]
if V.graph.sizevars.statically_known_leq(batch_stride, 1):
return str(batch_stride)
return "{}*{}".format(
self.find_symbol(node, "size", dim=1) or node.get_size()[1],
self.find_symbol(node, "size", dim=2) or node.get_size()[2],
)
def row_or_column_stride(self, node: IRNode, default_value: int = 0) -> str:
"""
Hook called from template code to get the row or column stride of an arg.

View File

@@ -39,7 +39,6 @@ GEMM_TEMPLATE_CUTLASS_3X = r"""
extern "C" {
PT_EXPORT {{kernel_call_signature}} {
try {
int B = {{kernel.size(Y, 0, -3, default_value=1)}};
using ElementComputeEpilogue = {{instance_type}}::ElementAccumulator;
using coord_t = cutlass::gemm::GemmCoord::Index;
static cutlass::KernelHardwareInfo hw_info;
@@ -110,13 +109,13 @@ GEMM_ARGS_CUTLASS_3X = r"""
{
{{template.cute_int(kernel.stride(X, -2), "stride_x0")}},
{{template.cute_int(kernel.stride(X, -1), "stride_x1")}},
{{template.cute_int(kernel.stride(X, -3), "batch_stride_x")}}
{{template.cute_int(kernel.batch_stride(X), "batch_stride_x")}}
}, // StrideA dA
{{template.cutlass_type_cast(W, kernel.ptr(W))}}, // ElementB const* ptr_B
{
{{template.cute_int(kernel.stride(W, -1), "stride_w1")}},
{{template.cute_int(kernel.stride(W, -2), "stride_w0")}},
{{template.cute_int(kernel.stride(W, -3), "batch_stride_w")}}
{{template.cute_int(kernel.batch_stride(W), "batch_stride_w")}}
}, // StrideB dB
}, // MainloopArguments mainloop
{{epilogue_arguments}},
@@ -135,13 +134,13 @@ GEMM_ARGS_CUTLASS_3X_EPILOGUE = r"""
{
{{template.cute_int(kernel.stride(Bias, -2, 1), "stride_bias0")}},
{{template.cute_int(kernel.stride(Bias, -1, 1), "stride_bias1")}},
{{template.cute_int(kernel.stride(Bias, -3), "batch_stride_bias")}}
{{template.cute_int(kernel.batch_stride(Bias), "batch_stride_bias")}}
}, // StrideC dC
{{template.cutlass_type_cast(Y, kernel.ptr(Y))}}, // ElementD const* ptr_D
{
{{template.cute_int(kernel.stride(Y, -2), "stride_y0")}},
{{template.cute_int(kernel.stride(Y, -1), "stride_y1")}},
{{template.cute_int(kernel.stride(Y, -3), "batch_stride_y")}}
{{template.cute_int(kernel.batch_stride(Y), "batch_stride_y")}}
}, // StrideD dD
}, // EpilogueArguments epilogue
"""
@@ -331,10 +330,11 @@ extern "C" int run_standalone(uint64_t seed, int repetitions) {
int M = {{kernel.get_layout_args()[0]}};
int N = {{kernel.get_layout_args()[1]}};
int K = {{kernel.get_layout_args()[2]}};
int lda = {{kernel.get_layout_args()[3]}};
int ldb = {{kernel.get_layout_args()[4]}};
int ldc = {{kernel.get_layout_args()[5]}};
int ldd = {{kernel.get_layout_args()[6]}};
int B = {{kernel.get_layout_args()[3]}};
int lda = {{kernel.get_layout_args()[4]}};
int ldb = {{kernel.get_layout_args()[5]}};
int ldc = {{kernel.get_layout_args()[6]}};
int ldd = {{kernel.get_layout_args()[7]}};
uint8_t swizzle = {{kernel.runtime_arg_values[0]}};
using ElementA = {{kernel.cutlass_dtype(X)}};
@@ -1092,7 +1092,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
f"(({arg_type}){arg_name}_data.get())"
for arg_type, arg_name in zip(arg_types, arg_names)
]
return f"{kernel.kernel_name}({', '.join(arguments)}, M, N, K, lda, ldb, ldc, ldd, swizzle, workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);" # noqa: B950
return f"{kernel.kernel_name}({', '.join(arguments)}, M, N, K, B, lda, ldb, ldc, ldd, swizzle, workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);" # noqa: B950
class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):

View File

@@ -23,6 +23,7 @@ from ..virtualized import V
from .mm_common import (
_is_static_problem,
addmm_epilogue,
is_batch_stride_largest,
mm_args,
mm_config_kwargs,
mm_options,
@@ -194,8 +195,9 @@ def tuned_bmm(mat1, mat2, *, layout=None):
layout=layout,
**mm_options(config, m, n, k, layout),
)
static_shape, is_nonzero = _is_static_problem(layout)
if static_shape and is_nonzero and use_cutlass_template(layout, m, n, k):
_, is_nonzero = _is_static_problem(layout)
batch_stride_largest = is_batch_stride_largest(mat1, mat2, layout)
if batch_stride_largest and is_nonzero and use_cutlass_template(layout, m, n, k):
from ..codegen.cuda.gemm_template import CUTLASS3xGemmTemplate
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2])

View File

@@ -7,6 +7,7 @@ import sympy
import torch
from torch._inductor.select_algorithm import realize_inputs, SymbolicGridFn
from torch._inductor.utils import sympy_product
from torch._inductor.virtualized import V
from .. import config as inductor_config
@@ -288,3 +289,17 @@ def check_supported_striding(mat_a: TensorBox, mat_b: TensorBox) -> None:
is_col_major(mat_b.get_stride()) or has_zero_dim(mat_b.get_size()),
lambda: f"mat_b must be col_major, got stride {mat_b.get_stride()}",
)
def is_batch_stride_largest(mat1, mat2, layout) -> bool:
"""
Checking if the batch stride is the largest in the stride.
"""
sizes = [mat1.get_size(), mat2.get_size(), layout.size]
strides = [mat1.get_stride(), mat2.get_stride(), layout.stride]
for size, stride in zip(sizes, strides):
assert len(size) == len(stride) == 3, "Expect 3D tensors"
if stride[0] != sympy_product(size[1:]):
return False
return True