Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
[cutlass backend] Add (limited) bmm dynamic shape support (#152393)
Differential Revision: D73626732

In this PR, we add support for bmm dynamic shapes, provided that the batch stride is the largest stride of A, B, and D. For example, for A of size `(B, M, K)`, we support strides `(M*K, K, 1)` and `(M*K, 1, M)`. With this assumption, we can infer the batch stride from the existing arguments, which lets us avoid adding 2-3 more runtime params. The concerns with extra params are complexity and a possible perf regression, though we did not verify the latter. We can revisit this if the need arises.

We also remove the `B = 1` special case for regular mm and addmm. We tested this and saw no perf regression, but we are open to revisiting it as well.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152393
Approved by: https://github.com/ColinPeppler
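For illustration, here is a minimal standalone sketch (mine, not part of the PR) of why the batch stride is recoverable from the existing runtime arguments under this assumption: when the batch dimension has the largest stride and each batch slice is densely packed, the batch strides of A, B, and D are simply `M*K`, `K*N`, and `M*N`.

```python
import torch

# Hypothetical sizes, only to demonstrate the stride arithmetic.
B, M, N, K = 4, 8, 24, 16

A = torch.randn(B, M, K)                        # stride (M*K, K, 1)
A_col = torch.randn(B, K, M).permute(0, 2, 1)   # stride (M*K, 1, M)
Bmat = torch.randn(B, K, N)                     # stride (K*N, N, 1)
D = torch.randn(B, M, N)                        # stride (M*N, N, 1)

# In every supported layout the batch stride equals the product of the
# two non-batch sizes, so it never needs to be passed separately.
assert A.stride(0) == A_col.stride(0) == M * K
assert Bmat.stride(0) == K * N
assert D.stride(0) == M * N
```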
This commit is contained in:
parent e5ea7911ea
commit ee2d104c05
@@ -455,7 +455,7 @@ class TestCutlassBackend(TestCase):
         torch.testing.assert_close(actual, expected)

     @unittest.skipIf(not SM90OrLater, "need sm_90")
-    @parametrize("dynamic", (False,))
+    @parametrize("dynamic", (False, True))
     @parametrize("use_aoti", (False, True))
     @parametrize("dtype", (torch.float16, torch.bfloat16))
     @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
@@ -478,15 +478,25 @@ class TestCutlassBackend(TestCase):
         # B, M, N, K
         shapes = [
             (10, 4096, 2048, 25728),
+            (20, 2048, 1024, 12864),
         ]
+        shapes = shapes[0:1] if not dynamic else shapes

         inputs = [
             (
                 torch.randn(B, M, K).cuda().to(dtype),
-                torch.randn(B, K, N).cuda().to(dtype),
+                torch.randn(B, N, K).cuda().to(dtype).permute(0, 2, 1),
             )
             for B, M, N, K in shapes
         ]
+        dynamic_shapes = (
+            {
+                "a": {0: Dim.DYNAMIC, 1: Dim.DYNAMIC, 2: Dim.DYNAMIC},
+                "b": {0: Dim.DYNAMIC, 1: Dim.DYNAMIC, 2: Dim.DYNAMIC},
+            }
+            if dynamic
+            else None
+        )
         with config.patch(
             {
                 "max_autotune": True,
@@ -497,7 +507,9 @@ class TestCutlassBackend(TestCase):
         ):
             expected = [model(*input) for input in inputs]
             if use_aoti:
-                actual = AOTIRunnerUtil.run_multiple(model, inputs, dynamic_shapes=None)
+                actual = AOTIRunnerUtil.run_multiple(
+                    model, inputs, dynamic_shapes=dynamic_shapes
+                )
             else:
                 compiled_model = torch.compile(model, dynamic=dynamic)
                 actual = [compiled_model(*input) for input in inputs]
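For context, a minimal standalone sketch (not taken from the test file) of how a `dynamic_shapes` dict like the one above is consumed: `torch.export.Dim.DYNAMIC` marks each dimension as symbolic at export time, which is the path the AOTI variant of the test exercises. The module, sizes, and CPU tensors below are assumptions for illustration; the real test uses CUDA half-precision inputs and the internal `AOTIRunnerUtil` helper.

```python
import torch
from torch.export import Dim, export


class BMM(torch.nn.Module):
    def forward(self, a, b):
        return torch.bmm(a, b)


a = torch.randn(4, 32, 16)
b = torch.randn(4, 16, 24)
dynamic_shapes = {
    "a": {0: Dim.DYNAMIC, 1: Dim.DYNAMIC, 2: Dim.DYNAMIC},
    "b": {0: Dim.DYNAMIC, 1: Dim.DYNAMIC, 2: Dim.DYNAMIC},
}
# Every dimension becomes a symbol, so the exported program (and an AOTI
# artifact compiled from it) can serve different B/M/N/K without re-export.
ep = export(BMM(), (a, b), dynamic_shapes=dynamic_shapes)
print(ep)
```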
@@ -1,5 +1,7 @@
 # mypy: allow-untyped-defs
+import itertools
 import logging
+from collections import defaultdict
 from dataclasses import dataclass
 from typing import Any, Callable, Literal, Optional, TYPE_CHECKING, Union

@@ -46,7 +48,7 @@ def _normalize_idx(index: int, total_length: int) -> int:
     return index if index >= 0 else index + total_length


-ValidLayoutSymbols = Literal["M", "N", "K", "lda", "ldb", "ldc", "ldd"]
+ValidLayoutSymbols = Literal["M", "N", "K", "B", "lda", "ldb", "ldc", "ldd"]
 ValidLayoutAttrs = Literal["size", "stride"]


@@ -70,7 +72,7 @@ class CUDAKernel(Kernel):

     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
-        self.layout_args: dict[str, LayoutArg] = {}
+        self.layout_args: dict[str, list[LayoutArg]] = defaultdict(list)
         # Mapping from arg name to IRNode.
         self.named_nodes: dict[str, IRNode] = {}

@@ -84,7 +86,9 @@ class CUDAKernel(Kernel):
         self, node: IRNode, attr: ValidLayoutAttrs, dim: int
     ) -> Optional[LayoutArg]:
         matches = [
-            arg for arg in self.layout_args.values() if arg.matches(node, attr, dim)
+            arg
+            for arg in itertools.chain.from_iterable(self.layout_args.values())
+            if arg.matches(node, attr, dim)
         ]
         if len(matches) >= 1:
             # Verify all matches have the same node, attribute, and dimension
@@ -105,19 +109,27 @@ class CUDAKernel(Kernel):
         self, symbol: ValidLayoutSymbols, node: IRNode, attr: ValidLayoutAttrs, dim: int
     ):
         arg = LayoutArg(node, symbol, attr, dim)
-        self.layout_args.setdefault(symbol, arg)
+        self.layout_args[symbol].append(arg)

     def init_layout_args(self) -> None:
         X = self.named_nodes["X"]
         W = self.named_nodes["W"]
         Y = self.named_nodes["Y"]
         Bias = self.named_nodes.get("Bias", None)
-        mdim = _normalize_idx(-2, len(X.get_size()))
-        ndim = _normalize_idx(-1, len(W.get_size()))
-        kdim = _normalize_idx(-1, len(X.get_size()))
-        self.add_layout_arg("M", X, "size", mdim)
-        self.add_layout_arg("N", W, "size", ndim)
-        self.add_layout_arg("K", X, "size", kdim)
+        x_mdim = _normalize_idx(-2, len(X.get_size()))
+        x_kdim = _normalize_idx(-1, len(X.get_size()))
+        w_kdim = _normalize_idx(-2, len(W.get_size()))
+        w_ndim = _normalize_idx(-1, len(W.get_size()))
+        y_mdim = _normalize_idx(-2, len(Y.get_size()))
+        y_ndim = _normalize_idx(-1, len(Y.get_size()))
+        self.add_layout_arg("M", X, "size", x_mdim)
+        self.add_layout_arg("K", X, "size", x_kdim)
+        self.add_layout_arg("K", W, "size", w_kdim)
+        self.add_layout_arg("N", W, "size", w_ndim)
+        self.add_layout_arg("M", Y, "size", y_mdim)
+        self.add_layout_arg("N", Y, "size", y_ndim)
+        if len(X.get_size()) > 2:
+            self.add_layout_arg("B", X, "size", 0)

         lda_dim = self.find_ld_idx(X)
         ldb_dim = self.find_ld_idx(W)
@@ -145,11 +157,12 @@ class CUDAKernel(Kernel):
         M = X.get_size()[mdim]
         N = W.get_size()[ndim]
         K = X.get_size()[kdim]
+        B = X.get_size()[0] if len(X.get_size()) > 2 else 1
         LDA = get_ld(X)
         LDB = get_ld(W)
         LDC = get_ld(Bias) if Bias else 0
         LDD = get_ld(Y)
-        return (M, N, K, LDA, LDB, LDC, LDD)
+        return (M, N, K, B, LDA, LDB, LDC, LDD)

     @staticmethod
     def find_ld_idx(node: IRNode) -> int:
@@ -264,7 +277,7 @@ class CUDATemplateKernel(CUDAKernel):

         self.init_layout_args()
         size_args = [
-            f"const int {s}" for s in ("M", "N", "K", "lda", "ldb", "ldc", "ldd")
+            f"const int {s}" for s in ("M", "N", "K", "B", "lda", "ldb", "ldc", "ldd")
         ]

         runtime_arg_decls = ",".join(
@@ -461,6 +474,29 @@ class CUDATemplateKernel(CUDAKernel):
             return str(stride)
         return self.find_symbol(node, "stride", dim=index) or str(stride)

+    def batch_stride(self, node: IRNode, default_value: int = 0) -> str:
+        """
+        Hook called from template code to get the batch stride of an arg.
+        Returns 0 if batch dim is not present.
+
+        This method assumes that batch stride is the largest stride.
+        """
+
+        if node is None:
+            return str(default_value)
+
+        if len(node.get_size()) < 3:
+            return str(default_value)
+
+        batch_stride = node.get_stride()[0]
+        if V.graph.sizevars.statically_known_leq(batch_stride, 1):
+            return str(batch_stride)
+
+        return "{}*{}".format(
+            self.find_symbol(node, "size", dim=1) or node.get_size()[1],
+            self.find_symbol(node, "size", dim=2) or node.get_size()[2],
+        )
+
     def row_or_column_stride(self, node: IRNode, default_value: int = 0) -> str:
         """
         Hook called from template code to get the row or column stride of an arg.
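To make the new `batch_stride` hook concrete, here is a small standalone imitation (an assumption-laden sketch, not Inductor code): for a 3D argument it emits the product of the two non-batch sizes, which is exactly the batch stride when the batch stride is the largest, and it falls back to the default for 2D arguments.

```python
from typing import Optional

import torch


def batch_stride_expr(t: Optional[torch.Tensor], default_value: int = 0) -> str:
    """Emit a C-style expression string for the batch stride, assuming the
    batch stride is the largest stride (standalone analogue of the hook)."""
    if t is None or t.dim() < 3:
        return str(default_value)   # no batch dim present
    if t.stride(0) <= 1:
        return str(t.stride(0))     # degenerate/broadcast batch stays literal
    # Inferred as size(1) * size(2): "M*K" for A, "K*N" for B, "M*N" for D.
    return f"{t.size(1)}*{t.size(2)}"


print(batch_stride_expr(torch.randn(10, 64, 32)))   # "64*32"
print(batch_stride_expr(torch.randn(64, 32)))       # "0"
```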
@@ -39,7 +39,6 @@ GEMM_TEMPLATE_CUTLASS_3X = r"""
 extern "C" {
 PT_EXPORT {{kernel_call_signature}} {
   try {
-  int B = {{kernel.size(Y, 0, -3, default_value=1)}};
   using ElementComputeEpilogue = {{instance_type}}::ElementAccumulator;
   using coord_t = cutlass::gemm::GemmCoord::Index;
   static cutlass::KernelHardwareInfo hw_info;
@@ -110,13 +109,13 @@ GEMM_ARGS_CUTLASS_3X = r"""
     {
       {{template.cute_int(kernel.stride(X, -2), "stride_x0")}},
       {{template.cute_int(kernel.stride(X, -1), "stride_x1")}},
-      {{template.cute_int(kernel.stride(X, -3), "batch_stride_x")}}
+      {{template.cute_int(kernel.batch_stride(X), "batch_stride_x")}}
     }, // StrideA dA
     {{template.cutlass_type_cast(W, kernel.ptr(W))}}, // ElementB const* ptr_B
     {
       {{template.cute_int(kernel.stride(W, -1), "stride_w1")}},
       {{template.cute_int(kernel.stride(W, -2), "stride_w0")}},
-      {{template.cute_int(kernel.stride(W, -3), "batch_stride_w")}}
+      {{template.cute_int(kernel.batch_stride(W), "batch_stride_w")}}
     }, // StrideB dB
   }, // MainloopArguments mainloop
   {{epilogue_arguments}},
@@ -135,13 +134,13 @@ GEMM_ARGS_CUTLASS_3X_EPILOGUE = r"""
     {
       {{template.cute_int(kernel.stride(Bias, -2, 1), "stride_bias0")}},
       {{template.cute_int(kernel.stride(Bias, -1, 1), "stride_bias1")}},
-      {{template.cute_int(kernel.stride(Bias, -3), "batch_stride_bias")}}
+      {{template.cute_int(kernel.batch_stride(Bias), "batch_stride_bias")}}
     }, // StrideC dC
     {{template.cutlass_type_cast(Y, kernel.ptr(Y))}}, // ElementD const* ptr_D
     {
       {{template.cute_int(kernel.stride(Y, -2), "stride_y0")}},
       {{template.cute_int(kernel.stride(Y, -1), "stride_y1")}},
-      {{template.cute_int(kernel.stride(Y, -3), "batch_stride_y")}}
+      {{template.cute_int(kernel.batch_stride(Y), "batch_stride_y")}}
     }, // StrideD dD
   }, // EpilogueArguments epilogue
"""
@@ -331,10 +330,11 @@ extern "C" int run_standalone(uint64_t seed, int repetitions) {
   int M = {{kernel.get_layout_args()[0]}};
   int N = {{kernel.get_layout_args()[1]}};
   int K = {{kernel.get_layout_args()[2]}};
-  int lda = {{kernel.get_layout_args()[3]}};
-  int ldb = {{kernel.get_layout_args()[4]}};
-  int ldc = {{kernel.get_layout_args()[5]}};
-  int ldd = {{kernel.get_layout_args()[6]}};
+  int B = {{kernel.get_layout_args()[3]}};
+  int lda = {{kernel.get_layout_args()[4]}};
+  int ldb = {{kernel.get_layout_args()[5]}};
+  int ldc = {{kernel.get_layout_args()[6]}};
+  int ldd = {{kernel.get_layout_args()[7]}};
   uint8_t swizzle = {{kernel.runtime_arg_values[0]}};

   using ElementA = {{kernel.cutlass_dtype(X)}};
@@ -1092,7 +1092,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
             f"(({arg_type}){arg_name}_data.get())"
             for arg_type, arg_name in zip(arg_types, arg_names)
         ]
-        return f"{kernel.kernel_name}({', '.join(arguments)}, M, N, K, lda, ldb, ldc, ldd, swizzle, workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);" # noqa: B950
+        return f"{kernel.kernel_name}({', '.join(arguments)}, M, N, K, B, lda, ldb, ldc, ldd, swizzle, workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);" # noqa: B950


 class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
@@ -23,6 +23,7 @@ from ..virtualized import V
 from .mm_common import (
     _is_static_problem,
     addmm_epilogue,
+    is_batch_stride_largest,
     mm_args,
     mm_config_kwargs,
     mm_options,
@@ -194,8 +195,9 @@ def tuned_bmm(mat1, mat2, *, layout=None):
                 layout=layout,
                 **mm_options(config, m, n, k, layout),
             )
-    static_shape, is_nonzero = _is_static_problem(layout)
-    if static_shape and is_nonzero and use_cutlass_template(layout, m, n, k):
+    _, is_nonzero = _is_static_problem(layout)
+    batch_stride_largest = is_batch_stride_largest(mat1, mat2, layout)
+    if batch_stride_largest and is_nonzero and use_cutlass_template(layout, m, n, k):
         from ..codegen.cuda.gemm_template import CUTLASS3xGemmTemplate

         CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2])
@@ -7,6 +7,7 @@ import sympy

 import torch
 from torch._inductor.select_algorithm import realize_inputs, SymbolicGridFn
+from torch._inductor.utils import sympy_product
 from torch._inductor.virtualized import V

 from .. import config as inductor_config
@@ -288,3 +289,17 @@ def check_supported_striding(mat_a: TensorBox, mat_b: TensorBox) -> None:
         is_col_major(mat_b.get_stride()) or has_zero_dim(mat_b.get_size()),
         lambda: f"mat_b must be col_major, got stride {mat_b.get_stride()}",
     )
+
+
+def is_batch_stride_largest(mat1, mat2, layout) -> bool:
+    """
+    Checking if the batch stride is the largest in the stride.
+    """
+    sizes = [mat1.get_size(), mat2.get_size(), layout.size]
+    strides = [mat1.get_stride(), mat2.get_stride(), layout.stride]
+    for size, stride in zip(sizes, strides):
+        assert len(size) == len(stride) == 3, "Expect 3D tensors"
+        if stride[0] != sympy_product(size[1:]):
+            return False
+
+    return True
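As a quick illustration of the gating above (using eager tensors rather than Inductor IR nodes, so this is an analogy rather than the real call): the check passes exactly when each 3D operand's batch stride equals the product of its two inner sizes, i.e. batches are densely packed; otherwise the CUTLASS bmm template is not offered as an autotune choice.

```python
import torch


def batch_stride_is_largest(*tensors: torch.Tensor) -> bool:
    # Eager-tensor analogue of is_batch_stride_largest.
    for t in tensors:
        assert t.dim() == 3, "Expect 3D tensors"
        if t.stride(0) != t.size(1) * t.size(2):
            return False
    return True


a = torch.randn(8, 32, 16)                    # stride (512, 16, 1)
b = torch.randn(8, 24, 16).permute(0, 2, 1)   # stride (384, 1, 16)
print(batch_stride_is_largest(a, b))          # True: CUTLASS bmm is eligible

c = torch.randn(32, 16, 8).permute(2, 0, 1)   # batch stride is the smallest
print(batch_stride_is_largest(a, c))          # False: only other backends compete
```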