diff --git a/test/inductor/test_cutlass_backend.py b/test/inductor/test_cutlass_backend.py
index e2bd359a72c..cbfd9d30b60 100644
--- a/test/inductor/test_cutlass_backend.py
+++ b/test/inductor/test_cutlass_backend.py
@@ -455,7 +455,7 @@ class TestCutlassBackend(TestCase):
         torch.testing.assert_close(actual, expected)

     @unittest.skipIf(not SM90OrLater, "need sm_90")
-    @parametrize("dynamic", (False,))
+    @parametrize("dynamic", (False, True))
     @parametrize("use_aoti", (False, True))
     @parametrize("dtype", (torch.float16, torch.bfloat16))
     @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
@@ -478,15 +478,25 @@ class TestCutlassBackend(TestCase):
         # B, M, N, K
         shapes = [
             (10, 4096, 2048, 25728),
+            (20, 2048, 1024, 12864),
         ]
+        shapes = shapes[0:1] if not dynamic else shapes
         inputs = [
             (
                 torch.randn(B, M, K).cuda().to(dtype),
-                torch.randn(B, K, N).cuda().to(dtype),
+                torch.randn(B, N, K).cuda().to(dtype).permute(0, 2, 1),
             )
             for B, M, N, K in shapes
         ]
+        dynamic_shapes = (
+            {
+                "a": {0: Dim.DYNAMIC, 1: Dim.DYNAMIC, 2: Dim.DYNAMIC},
+                "b": {0: Dim.DYNAMIC, 1: Dim.DYNAMIC, 2: Dim.DYNAMIC},
+            }
+            if dynamic
+            else None
+        )

         with config.patch(
             {
                 "max_autotune": True,
@@ -497,7 +507,9 @@ class TestCutlassBackend(TestCase):
         ):
             expected = [model(*input) for input in inputs]
             if use_aoti:
-                actual = AOTIRunnerUtil.run_multiple(model, inputs, dynamic_shapes=None)
+                actual = AOTIRunnerUtil.run_multiple(
+                    model, inputs, dynamic_shapes=dynamic_shapes
+                )
             else:
                 compiled_model = torch.compile(model, dynamic=dynamic)
                 actual = [compiled_model(*input) for input in inputs]
diff --git a/torch/_inductor/codegen/cuda/cuda_kernel.py b/torch/_inductor/codegen/cuda/cuda_kernel.py
index e6df1e90144..0f2d694cd07 100644
--- a/torch/_inductor/codegen/cuda/cuda_kernel.py
+++ b/torch/_inductor/codegen/cuda/cuda_kernel.py
@@ -1,5 +1,7 @@
 # mypy: allow-untyped-defs
+import itertools
 import logging
+from collections import defaultdict
 from dataclasses import dataclass
 from typing import Any, Callable, Literal, Optional, TYPE_CHECKING, Union

@@ -46,7 +48,7 @@ def _normalize_idx(index: int, total_length: int) -> int:
     return index if index >= 0 else index + total_length


-ValidLayoutSymbols = Literal["M", "N", "K", "lda", "ldb", "ldc", "ldd"]
+ValidLayoutSymbols = Literal["M", "N", "K", "B", "lda", "ldb", "ldc", "ldd"]
 ValidLayoutAttrs = Literal["size", "stride"]


@@ -70,7 +72,7 @@ class CUDAKernel(Kernel):

     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
-        self.layout_args: dict[str, LayoutArg] = {}
+        self.layout_args: dict[str, list[LayoutArg]] = defaultdict(list)
         # Mapping from arg name to IRNode.
         self.named_nodes: dict[str, IRNode] = {}

@@ -84,7 +86,9 @@ class CUDAKernel(Kernel):
         self, node: IRNode, attr: ValidLayoutAttrs, dim: int
     ) -> Optional[LayoutArg]:
         matches = [
-            arg for arg in self.layout_args.values() if arg.matches(node, attr, dim)
+            arg
+            for arg in itertools.chain.from_iterable(self.layout_args.values())
+            if arg.matches(node, attr, dim)
         ]
         if len(matches) >= 1:
             # Verify all matches have the same node, attribute, and dimension
@@ -105,19 +109,27 @@ class CUDAKernel(Kernel):
         self, symbol: ValidLayoutSymbols, node: IRNode, attr: ValidLayoutAttrs, dim: int
     ):
         arg = LayoutArg(node, symbol, attr, dim)
-        self.layout_args.setdefault(symbol, arg)
+        self.layout_args[symbol].append(arg)

     def init_layout_args(self) -> None:
         X = self.named_nodes["X"]
         W = self.named_nodes["W"]
         Y = self.named_nodes["Y"]
         Bias = self.named_nodes.get("Bias", None)
-        mdim = _normalize_idx(-2, len(X.get_size()))
-        ndim = _normalize_idx(-1, len(W.get_size()))
-        kdim = _normalize_idx(-1, len(X.get_size()))
-        self.add_layout_arg("M", X, "size", mdim)
-        self.add_layout_arg("N", W, "size", ndim)
-        self.add_layout_arg("K", X, "size", kdim)
+        x_mdim = _normalize_idx(-2, len(X.get_size()))
+        x_kdim = _normalize_idx(-1, len(X.get_size()))
+        w_kdim = _normalize_idx(-2, len(W.get_size()))
+        w_ndim = _normalize_idx(-1, len(W.get_size()))
+        y_mdim = _normalize_idx(-2, len(Y.get_size()))
+        y_ndim = _normalize_idx(-1, len(Y.get_size()))
+        self.add_layout_arg("M", X, "size", x_mdim)
+        self.add_layout_arg("K", X, "size", x_kdim)
+        self.add_layout_arg("K", W, "size", w_kdim)
+        self.add_layout_arg("N", W, "size", w_ndim)
+        self.add_layout_arg("M", Y, "size", y_mdim)
+        self.add_layout_arg("N", Y, "size", y_ndim)
+        if len(X.get_size()) > 2:
+            self.add_layout_arg("B", X, "size", 0)

         lda_dim = self.find_ld_idx(X)
         ldb_dim = self.find_ld_idx(W)
@@ -145,11 +157,12 @@ class CUDAKernel(Kernel):
         M = X.get_size()[mdim]
         N = W.get_size()[ndim]
         K = X.get_size()[kdim]
+        B = X.get_size()[0] if len(X.get_size()) > 2 else 1
         LDA = get_ld(X)
         LDB = get_ld(W)
         LDC = get_ld(Bias) if Bias else 0
         LDD = get_ld(Y)
-        return (M, N, K, LDA, LDB, LDC, LDD)
+        return (M, N, K, B, LDA, LDB, LDC, LDD)

     @staticmethod
     def find_ld_idx(node: IRNode) -> int:
@@ -264,7 +277,7 @@ class CUDATemplateKernel(CUDAKernel):
         self.init_layout_args()

         size_args = [
-            f"const int {s}" for s in ("M", "N", "K", "lda", "ldb", "ldc", "ldd")
+            f"const int {s}" for s in ("M", "N", "K", "B", "lda", "ldb", "ldc", "ldd")
         ]

         runtime_arg_decls = ",".join(
@@ -461,6 +474,29 @@ class CUDATemplateKernel(CUDAKernel):
             return str(stride)
         return self.find_symbol(node, "stride", dim=index) or str(stride)

+    def batch_stride(self, node: IRNode, default_value: int = 0) -> str:
+        """
+        Hook called from template code to get the batch stride of an arg.
+        Returns 0 if batch dim is not present.
+
+        This method assumes that batch stride is the largest stride.
+        """
+
+        if node is None:
+            return str(default_value)
+
+        if len(node.get_size()) < 3:
+            return str(default_value)
+
+        batch_stride = node.get_stride()[0]
+        if V.graph.sizevars.statically_known_leq(batch_stride, 1):
+            return str(batch_stride)
+
+        return "{}*{}".format(
+            self.find_symbol(node, "size", dim=1) or node.get_size()[1],
+            self.find_symbol(node, "size", dim=2) or node.get_size()[2],
+        )
+
     def row_or_column_stride(self, node: IRNode, default_value: int = 0) -> str:
         """
         Hook called from template code to get the row or column stride of an arg.
diff --git a/torch/_inductor/codegen/cuda/gemm_template.py b/torch/_inductor/codegen/cuda/gemm_template.py
index ee9cb40851f..56b39d3a125 100644
--- a/torch/_inductor/codegen/cuda/gemm_template.py
+++ b/torch/_inductor/codegen/cuda/gemm_template.py
@@ -39,7 +39,6 @@ GEMM_TEMPLATE_CUTLASS_3X = r"""
 extern "C" {
 PT_EXPORT {{kernel_call_signature}} {
   try {
-  int B = {{kernel.size(Y, 0, -3, default_value=1)}};
   using ElementComputeEpilogue = {{instance_type}}::ElementAccumulator;
   using coord_t = cutlass::gemm::GemmCoord::Index;
   static cutlass::KernelHardwareInfo hw_info;
@@ -110,13 +109,13 @@ GEMM_ARGS_CUTLASS_3X = r"""
     {
       {{template.cute_int(kernel.stride(X, -2), "stride_x0")}},
       {{template.cute_int(kernel.stride(X, -1), "stride_x1")}},
-      {{template.cute_int(kernel.stride(X, -3), "batch_stride_x")}}
+      {{template.cute_int(kernel.batch_stride(X), "batch_stride_x")}}
     }, // StrideA dA
     {{template.cutlass_type_cast(W, kernel.ptr(W))}}, // ElementB const* ptr_B
     {
       {{template.cute_int(kernel.stride(W, -1), "stride_w1")}},
       {{template.cute_int(kernel.stride(W, -2), "stride_w0")}},
-      {{template.cute_int(kernel.stride(W, -3), "batch_stride_w")}}
+      {{template.cute_int(kernel.batch_stride(W), "batch_stride_w")}}
     }, // StrideB dB
   }, // MainloopArguments mainloop
   {{epilogue_arguments}},
@@ -135,13 +134,13 @@ GEMM_ARGS_CUTLASS_3X_EPILOGUE = r"""
     {
       {{template.cute_int(kernel.stride(Bias, -2, 1), "stride_bias0")}},
       {{template.cute_int(kernel.stride(Bias, -1, 1), "stride_bias1")}},
-      {{template.cute_int(kernel.stride(Bias, -3), "batch_stride_bias")}}
+      {{template.cute_int(kernel.batch_stride(Bias), "batch_stride_bias")}}
     }, // StrideC dC
     {{template.cutlass_type_cast(Y, kernel.ptr(Y))}}, // ElementD const* ptr_D
     {
       {{template.cute_int(kernel.stride(Y, -2), "stride_y0")}},
       {{template.cute_int(kernel.stride(Y, -1), "stride_y1")}},
-      {{template.cute_int(kernel.stride(Y, -3), "batch_stride_y")}}
+      {{template.cute_int(kernel.batch_stride(Y), "batch_stride_y")}}
     }, // StrideD dD
   }, // EpilogueArguments epilogue
 """
@@ -331,10 +330,11 @@ extern "C" int run_standalone(uint64_t seed, int repetitions) {
   int M = {{kernel.get_layout_args()[0]}};
   int N = {{kernel.get_layout_args()[1]}};
   int K = {{kernel.get_layout_args()[2]}};
-  int lda = {{kernel.get_layout_args()[3]}};
-  int ldb = {{kernel.get_layout_args()[4]}};
-  int ldc = {{kernel.get_layout_args()[5]}};
-  int ldd = {{kernel.get_layout_args()[6]}};
+  int B = {{kernel.get_layout_args()[3]}};
+  int lda = {{kernel.get_layout_args()[4]}};
+  int ldb = {{kernel.get_layout_args()[5]}};
+  int ldc = {{kernel.get_layout_args()[6]}};
+  int ldd = {{kernel.get_layout_args()[7]}};
   uint8_t swizzle = {{kernel.runtime_arg_values[0]}};

   using ElementA = {{kernel.cutlass_dtype(X)}};
@@ -1092,7 +1092,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
             f"(({arg_type}){arg_name}_data.get())"
             for arg_type, arg_name in zip(arg_types, arg_names)
         ]
-        return f"{kernel.kernel_name}({', '.join(arguments)}, M, N, K, lda, ldb, ldc, ldd, swizzle, workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);"  # noqa: B950
+        return f"{kernel.kernel_name}({', '.join(arguments)}, M, N, K, B, lda, ldb, ldc, ldd, swizzle, workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);"  # noqa: B950


 class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
diff --git a/torch/_inductor/kernel/bmm.py b/torch/_inductor/kernel/bmm.py
index cd074e2c36d..a23536d55bb 100644
--- a/torch/_inductor/kernel/bmm.py
+++ b/torch/_inductor/kernel/bmm.py
@@ -23,6 +23,7 @@ from ..virtualized import V
 from .mm_common import (
     _is_static_problem,
     addmm_epilogue,
+    is_batch_stride_largest,
     mm_args,
     mm_config_kwargs,
     mm_options,
@@ -194,8 +195,9 @@ def tuned_bmm(mat1, mat2, *, layout=None):
                 layout=layout,
                 **mm_options(config, m, n, k, layout),
             )
-    static_shape, is_nonzero = _is_static_problem(layout)
-    if static_shape and is_nonzero and use_cutlass_template(layout, m, n, k):
+    _, is_nonzero = _is_static_problem(layout)
+    batch_stride_largest = is_batch_stride_largest(mat1, mat2, layout)
+    if batch_stride_largest and is_nonzero and use_cutlass_template(layout, m, n, k):
         from ..codegen.cuda.gemm_template import CUTLASS3xGemmTemplate

         CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2])
diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py
index acd5f85851f..3dae93adde1 100644
--- a/torch/_inductor/kernel/mm_common.py
+++ b/torch/_inductor/kernel/mm_common.py
@@ -7,6 +7,7 @@ import sympy

 import torch
 from torch._inductor.select_algorithm import realize_inputs, SymbolicGridFn
+from torch._inductor.utils import sympy_product
 from torch._inductor.virtualized import V

 from .. import config as inductor_config
@@ -288,3 +289,17 @@ def check_supported_striding(mat_a: TensorBox, mat_b: TensorBox) -> None:
         is_col_major(mat_b.get_stride()) or has_zero_dim(mat_b.get_size()),
         lambda: f"mat_b must be col_major, got stride {mat_b.get_stride()}",
     )
+
+
+def is_batch_stride_largest(mat1, mat2, layout) -> bool:
+    """
+    Check whether the batch stride is the largest stride, i.e. it equals the product of the inner sizes.
+    """
+    sizes = [mat1.get_size(), mat2.get_size(), layout.size]
+    strides = [mat1.get_stride(), mat2.get_stride(), layout.stride]
+    for size, stride in zip(sizes, strides):
+        assert len(size) == len(stride) == 3, "Expect 3D tensors"
+        if stride[0] != sympy_product(size[1:]):
+            return False
+
+    return True
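Note on the new eligibility gate: `is_batch_stride_largest` only admits `bmm` operands whose batch stride equals the product of the two inner sizes, i.e. each batch is a dense, outermost 2D slice, which is exactly the quantity `CUDATemplateKernel.batch_stride` emits for the CUTLASS 3.x stride tuples. The sketch below is an illustration of that condition on plain `torch.Tensor`s; the real helper operates on inductor IR nodes and sympy expressions, and the helper name and shapes here are made up for the example.

```python
import torch


def batch_stride_is_largest(t: torch.Tensor) -> bool:
    # Hypothetical stand-in mirroring is_batch_stride_largest: the batch stride
    # must equal size(1) * size(2), i.e. each batch is a dense 2D block.
    assert t.dim() == 3, "Expect 3D tensors"
    return t.stride(0) == t.size(1) * t.size(2)


a = torch.randn(10, 64, 32)                   # contiguous: stride(0) == 64 * 32
b = torch.randn(10, 16, 32).permute(0, 2, 1)  # inner dims transposed, batch still outermost
c = torch.randn(64, 32).expand(10, 64, 32)    # broadcast batch: stride(0) == 0

print(batch_stride_is_largest(a))  # True
print(batch_stride_is_largest(b))  # True  (stride(0) == 16 * 32 == 512)
print(batch_stride_is_largest(c))  # False -> bmm keeps non-CUTLASS choices only
```

For reference, the updated test input `torch.randn(B, N, K).permute(0, 2, 1)` also satisfies this condition: its batch stride is still `N * K`, the product of the inner sizes, even though the inner two dimensions are no longer contiguous.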