[cutlass backend] Add (limited) bmm dynamic shape support (#152393)

Differential Revision: D73626732

In this PR, we add (limited) dynamic shape support for bmm, provided that the batch stride is the largest stride of A, B, and D. For example, for A of size `(B, M, K)`, we support strides `(M*K, K, 1)` and `(M*K, 1, M)`. Under this assumption, the batch stride can be inferred from the existing kernel arguments.
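For intuition, here is a minimal sketch (plain Python, not the inductor code added in this PR) of how the batch stride can be recovered from `M` and `K` under that assumption:

```python
# Minimal sketch, not the actual inductor code: recover the batch stride of a
# 3D operand of size (B, M, K) assuming the batch stride is the largest stride.
def infer_batch_stride(size, stride):
    B, M, K = size
    # Supported layouts for A of size (B, M, K):
    #   row-major within each batch:    stride == (M*K, K, 1)
    #   column-major within each batch: stride == (M*K, 1, M)
    # In both cases the batch stride equals M * K, so it can be derived from
    # sizes that are already kernel arguments instead of a new runtime param.
    assert stride[0] == M * K, "batch stride must be the largest stride"
    return M * K

assert infer_batch_stride((10, 4096, 25728), (4096 * 25728, 25728, 1)) == 4096 * 25728
assert infer_batch_stride((10, 4096, 25728), (4096 * 25728, 1, 4096)) == 4096 * 25728
```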

The reason is that we don't want to add 2-3 more runtime params. The concerns are added complexity and a possible perf regression, though we didn't verify the latter.

We can revisit this if the need arises.

We also remove the hardcoded `B = 1` from the template for normal mm and addmm; `B` is now passed through the kernel size args instead. We tested this and didn't see a perf regression, but we're open to revisiting it as well.
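For reference, a rough usage sketch of the path this PR targets (assumes an SM90 GPU and a CUTLASS-enabled build; the config knobs mirror the test changes below, exact flags may differ by version):

```python
# Sketch: exercising the CUTLASS bmm path with dynamic shapes via torch.compile.
import torch
from torch._inductor import config


class BMM(torch.nn.Module):
    def forward(self, a, b):
        return torch.bmm(a, b)


with config.patch({"max_autotune": True, "max_autotune_gemm_backends": "CUTLASS"}):
    model = torch.compile(BMM(), dynamic=True)
    a = torch.randn(10, 4096, 25728, device="cuda", dtype=torch.float16)
    b = torch.randn(10, 25728, 2048, device="cuda", dtype=torch.float16)
    # Batch stride of a and b is the largest stride, so the CUTLASS bmm template is eligible.
    out = model(a, b)
```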

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152393
Approved by: https://github.com/ColinPeppler
Authored by Henry Tsang on 2025-04-30 04:36:24 +00:00; committed by PyTorch MergeBot
parent e5ea7911ea
commit ee2d104c05
5 changed files with 92 additions and 27 deletions

View File

@@ -455,7 +455,7 @@ class TestCutlassBackend(TestCase):
             torch.testing.assert_close(actual, expected)

     @unittest.skipIf(not SM90OrLater, "need sm_90")
-    @parametrize("dynamic", (False,))
+    @parametrize("dynamic", (False, True))
     @parametrize("use_aoti", (False, True))
     @parametrize("dtype", (torch.float16, torch.bfloat16))
     @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
@@ -478,15 +478,25 @@ class TestCutlassBackend(TestCase):
         # B, M, N, K
         shapes = [
             (10, 4096, 2048, 25728),
+            (20, 2048, 1024, 12864),
         ]
+        shapes = shapes[0:1] if not dynamic else shapes
         inputs = [
             (
                 torch.randn(B, M, K).cuda().to(dtype),
-                torch.randn(B, K, N).cuda().to(dtype),
+                torch.randn(B, N, K).cuda().to(dtype).permute(0, 2, 1),
             )
             for B, M, N, K in shapes
         ]
+        dynamic_shapes = (
+            {
+                "a": {0: Dim.DYNAMIC, 1: Dim.DYNAMIC, 2: Dim.DYNAMIC},
+                "b": {0: Dim.DYNAMIC, 1: Dim.DYNAMIC, 2: Dim.DYNAMIC},
+            }
+            if dynamic
+            else None
+        )
         with config.patch(
             {
                 "max_autotune": True,
@@ -497,7 +507,9 @@ class TestCutlassBackend(TestCase):
         ):
             expected = [model(*input) for input in inputs]
             if use_aoti:
-                actual = AOTIRunnerUtil.run_multiple(model, inputs, dynamic_shapes=None)
+                actual = AOTIRunnerUtil.run_multiple(
+                    model, inputs, dynamic_shapes=dynamic_shapes
+                )
             else:
                 compiled_model = torch.compile(model, dynamic=dynamic)
                 actual = [compiled_model(*input) for input in inputs]

View File

@@ -1,5 +1,7 @@
 # mypy: allow-untyped-defs
+import itertools
 import logging
+from collections import defaultdict
 from dataclasses import dataclass
 from typing import Any, Callable, Literal, Optional, TYPE_CHECKING, Union
@@ -46,7 +48,7 @@ def _normalize_idx(index: int, total_length: int) -> int:
     return index if index >= 0 else index + total_length

-ValidLayoutSymbols = Literal["M", "N", "K", "lda", "ldb", "ldc", "ldd"]
+ValidLayoutSymbols = Literal["M", "N", "K", "B", "lda", "ldb", "ldc", "ldd"]
 ValidLayoutAttrs = Literal["size", "stride"]
@@ -70,7 +72,7 @@ class CUDAKernel(Kernel):
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
-        self.layout_args: dict[str, LayoutArg] = {}
+        self.layout_args: dict[str, list[LayoutArg]] = defaultdict(list)
         # Mapping from arg name to IRNode.
         self.named_nodes: dict[str, IRNode] = {}
@@ -84,7 +86,9 @@ class CUDAKernel(Kernel):
         self, node: IRNode, attr: ValidLayoutAttrs, dim: int
     ) -> Optional[LayoutArg]:
         matches = [
-            arg for arg in self.layout_args.values() if arg.matches(node, attr, dim)
+            arg
+            for arg in itertools.chain.from_iterable(self.layout_args.values())
+            if arg.matches(node, attr, dim)
         ]
         if len(matches) >= 1:
             # Verify all matches have the same node, attribute, and dimension
@@ -105,19 +109,27 @@ class CUDAKernel(Kernel):
         self, symbol: ValidLayoutSymbols, node: IRNode, attr: ValidLayoutAttrs, dim: int
     ):
         arg = LayoutArg(node, symbol, attr, dim)
-        self.layout_args.setdefault(symbol, arg)
+        self.layout_args[symbol].append(arg)

     def init_layout_args(self) -> None:
         X = self.named_nodes["X"]
         W = self.named_nodes["W"]
         Y = self.named_nodes["Y"]
         Bias = self.named_nodes.get("Bias", None)
-        mdim = _normalize_idx(-2, len(X.get_size()))
-        ndim = _normalize_idx(-1, len(W.get_size()))
-        kdim = _normalize_idx(-1, len(X.get_size()))
-        self.add_layout_arg("M", X, "size", mdim)
-        self.add_layout_arg("N", W, "size", ndim)
-        self.add_layout_arg("K", X, "size", kdim)
+        x_mdim = _normalize_idx(-2, len(X.get_size()))
+        x_kdim = _normalize_idx(-1, len(X.get_size()))
+        w_kdim = _normalize_idx(-2, len(W.get_size()))
+        w_ndim = _normalize_idx(-1, len(W.get_size()))
+        y_mdim = _normalize_idx(-2, len(Y.get_size()))
+        y_ndim = _normalize_idx(-1, len(Y.get_size()))
+        self.add_layout_arg("M", X, "size", x_mdim)
+        self.add_layout_arg("K", X, "size", x_kdim)
+        self.add_layout_arg("K", W, "size", w_kdim)
+        self.add_layout_arg("N", W, "size", w_ndim)
+        self.add_layout_arg("M", Y, "size", y_mdim)
+        self.add_layout_arg("N", Y, "size", y_ndim)
+        if len(X.get_size()) > 2:
+            self.add_layout_arg("B", X, "size", 0)

         lda_dim = self.find_ld_idx(X)
         ldb_dim = self.find_ld_idx(W)
@@ -145,11 +157,12 @@ class CUDAKernel(Kernel):
         M = X.get_size()[mdim]
         N = W.get_size()[ndim]
         K = X.get_size()[kdim]
+        B = X.get_size()[0] if len(X.get_size()) > 2 else 1
         LDA = get_ld(X)
         LDB = get_ld(W)
         LDC = get_ld(Bias) if Bias else 0
         LDD = get_ld(Y)
-        return (M, N, K, LDA, LDB, LDC, LDD)
+        return (M, N, K, B, LDA, LDB, LDC, LDD)

     @staticmethod
     def find_ld_idx(node: IRNode) -> int:
@@ -264,7 +277,7 @@ class CUDATemplateKernel(CUDAKernel):
         self.init_layout_args()
         size_args = [
-            f"const int {s}" for s in ("M", "N", "K", "lda", "ldb", "ldc", "ldd")
+            f"const int {s}" for s in ("M", "N", "K", "B", "lda", "ldb", "ldc", "ldd")
         ]

         runtime_arg_decls = ",".join(
@@ -461,6 +474,29 @@ class CUDATemplateKernel(CUDAKernel):
             return str(stride)
         return self.find_symbol(node, "stride", dim=index) or str(stride)

+    def batch_stride(self, node: IRNode, default_value: int = 0) -> str:
+        """
+        Hook called from template code to get the batch stride of an arg.
+        Returns 0 if batch dim is not present.
+        This method assumes that batch stride is the largest stride.
+        """
+        if node is None:
+            return str(default_value)
+        if len(node.get_size()) < 3:
+            return str(default_value)
+        batch_stride = node.get_stride()[0]
+        if V.graph.sizevars.statically_known_leq(batch_stride, 1):
+            return str(batch_stride)
+        return "{}*{}".format(
+            self.find_symbol(node, "size", dim=1) or node.get_size()[1],
+            self.find_symbol(node, "size", dim=2) or node.get_size()[2],
+        )
+
     def row_or_column_stride(self, node: IRNode, default_value: int = 0) -> str:
         """
         Hook called from template code to get the row or column stride of an arg.

View File

@@ -39,7 +39,6 @@ GEMM_TEMPLATE_CUTLASS_3X = r"""
 extern "C" {
 PT_EXPORT {{kernel_call_signature}} {
   try {
-  int B = {{kernel.size(Y, 0, -3, default_value=1)}};
   using ElementComputeEpilogue = {{instance_type}}::ElementAccumulator;
   using coord_t = cutlass::gemm::GemmCoord::Index;
   static cutlass::KernelHardwareInfo hw_info;
@@ -110,13 +109,13 @@ GEMM_ARGS_CUTLASS_3X = r"""
      {
        {{template.cute_int(kernel.stride(X, -2), "stride_x0")}},
        {{template.cute_int(kernel.stride(X, -1), "stride_x1")}},
-       {{template.cute_int(kernel.stride(X, -3), "batch_stride_x")}}
+       {{template.cute_int(kernel.batch_stride(X), "batch_stride_x")}}
      },  // StrideA dA
      {{template.cutlass_type_cast(W, kernel.ptr(W))}},  // ElementB const* ptr_B
      {
        {{template.cute_int(kernel.stride(W, -1), "stride_w1")}},
        {{template.cute_int(kernel.stride(W, -2), "stride_w0")}},
-       {{template.cute_int(kernel.stride(W, -3), "batch_stride_w")}}
+       {{template.cute_int(kernel.batch_stride(W), "batch_stride_w")}}
      },  // StrideB dB
    },  // MainloopArguments mainloop
    {{epilogue_arguments}},
@@ -135,13 +134,13 @@ GEMM_ARGS_CUTLASS_3X_EPILOGUE = r"""
      {
        {{template.cute_int(kernel.stride(Bias, -2, 1), "stride_bias0")}},
        {{template.cute_int(kernel.stride(Bias, -1, 1), "stride_bias1")}},
-       {{template.cute_int(kernel.stride(Bias, -3), "batch_stride_bias")}}
+       {{template.cute_int(kernel.batch_stride(Bias), "batch_stride_bias")}}
      },  // StrideC dC
      {{template.cutlass_type_cast(Y, kernel.ptr(Y))}},  // ElementD const* ptr_D
      {
        {{template.cute_int(kernel.stride(Y, -2), "stride_y0")}},
        {{template.cute_int(kernel.stride(Y, -1), "stride_y1")}},
-       {{template.cute_int(kernel.stride(Y, -3), "batch_stride_y")}}
+       {{template.cute_int(kernel.batch_stride(Y), "batch_stride_y")}}
      },  // StrideD dD
    },  // EpilogueArguments epilogue
 """
@@ -331,10 +330,11 @@ extern "C" int run_standalone(uint64_t seed, int repetitions) {
   int M = {{kernel.get_layout_args()[0]}};
   int N = {{kernel.get_layout_args()[1]}};
   int K = {{kernel.get_layout_args()[2]}};
-  int lda = {{kernel.get_layout_args()[3]}};
-  int ldb = {{kernel.get_layout_args()[4]}};
-  int ldc = {{kernel.get_layout_args()[5]}};
-  int ldd = {{kernel.get_layout_args()[6]}};
+  int B = {{kernel.get_layout_args()[3]}};
+  int lda = {{kernel.get_layout_args()[4]}};
+  int ldb = {{kernel.get_layout_args()[5]}};
+  int ldc = {{kernel.get_layout_args()[6]}};
+  int ldd = {{kernel.get_layout_args()[7]}};
   uint8_t swizzle = {{kernel.runtime_arg_values[0]}};

   using ElementA = {{kernel.cutlass_dtype(X)}};
@@ -1092,7 +1092,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
             f"(({arg_type}){arg_name}_data.get())"
             for arg_type, arg_name in zip(arg_types, arg_names)
         ]
-        return f"{kernel.kernel_name}({', '.join(arguments)}, M, N, K, lda, ldb, ldc, ldd, swizzle, workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);"  # noqa: B950
+        return f"{kernel.kernel_name}({', '.join(arguments)}, M, N, K, B, lda, ldb, ldc, ldd, swizzle, workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);"  # noqa: B950

 class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):

View File

@@ -23,6 +23,7 @@ from ..virtualized import V
 from .mm_common import (
     _is_static_problem,
     addmm_epilogue,
+    is_batch_stride_largest,
     mm_args,
     mm_config_kwargs,
     mm_options,
@@ -194,8 +195,9 @@ def tuned_bmm(mat1, mat2, *, layout=None):
                 layout=layout,
                 **mm_options(config, m, n, k, layout),
             )
-    static_shape, is_nonzero = _is_static_problem(layout)
-    if static_shape and is_nonzero and use_cutlass_template(layout, m, n, k):
+    _, is_nonzero = _is_static_problem(layout)
+    batch_stride_largest = is_batch_stride_largest(mat1, mat2, layout)
+    if batch_stride_largest and is_nonzero and use_cutlass_template(layout, m, n, k):
         from ..codegen.cuda.gemm_template import CUTLASS3xGemmTemplate

         CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2])

View File

@@ -7,6 +7,7 @@
 import sympy

 import torch
 from torch._inductor.select_algorithm import realize_inputs, SymbolicGridFn
+from torch._inductor.utils import sympy_product
 from torch._inductor.virtualized import V

 from .. import config as inductor_config
@@ -288,3 +289,17 @@ def check_supported_striding(mat_a: TensorBox, mat_b: TensorBox) -> None:
         is_col_major(mat_b.get_stride()) or has_zero_dim(mat_b.get_size()),
         lambda: f"mat_b must be col_major, got stride {mat_b.get_stride()}",
     )
+
+
+def is_batch_stride_largest(mat1, mat2, layout) -> bool:
+    """
+    Checking if the batch stride is the largest in the stride.
+    """
+    sizes = [mat1.get_size(), mat2.get_size(), layout.size]
+    strides = [mat1.get_stride(), mat2.get_stride(), layout.stride]
+    for size, stride in zip(sizes, strides):
+        assert len(size) == len(stride) == 3, "Expect 3D tensors"
+        if stride[0] != sympy_product(size[1:]):
+            return False
+
+    return True
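For intuition about how this gate behaves, here is a plain-Python analogue (illustrative only; the real helper above operates on inductor IR nodes and sympy expressions, not concrete tuples):

```python
# Plain-Python analogue of the is_batch_stride_largest gate, for intuition only.
import math

def batch_stride_is_largest(size, stride):
    assert len(size) == len(stride) == 3, "Expect 3D tensors"
    return stride[0] == math.prod(size[1:])

# A (B, M, K) operand that is contiguous (or transposed) within each batch qualifies,
assert batch_stride_is_largest((10, 4096, 25728), (4096 * 25728, 25728, 1))
# while a layout whose batch dimension is innermost does not, so tuned_bmm would not
# add CUTLASS choices for such inputs.
assert not batch_stride_is_largest((10, 4096, 25728), (1, 25728 * 10, 10))
```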