From 0b62465b99b23cb4afcd07424676cce34a676041 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 19 Jun 2025 14:25:38 +0000 Subject: [PATCH] Revert "Refine alignment check along dynamic dimension for grouped MMs (#155466)" This reverts commit 830a335a7da5fec00395d440ba568749cb4e2e9e. Reverted https://github.com/pytorch/pytorch/pull/155466 on behalf of https://github.com/atalman due to breaks internal builds ([comment](https://github.com/pytorch/pytorch/pull/155466#issuecomment-2988285117)) --- aten/src/ATen/native/cuda/Blas.cpp | 48 ++++++----------- aten/src/ATen/native/cuda/GroupMMCommon.cuh | 41 ++------------- test/test_matmul_cuda.py | 57 ++++----------------- torch/_inductor/kernel/mm_scaled_grouped.py | 5 +- torch/_meta_registrations.py | 38 +++++++------- 5 files changed, 53 insertions(+), 136 deletions(-) diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index 1834839bb6e..efa4af02a44 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -36,7 +36,6 @@ #include #include #include -#include #include #include #include @@ -1482,49 +1481,29 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, } namespace { - at::Tensor create_grouped_gemm_output_tensor(const Tensor& mat_a, + c10::SmallVector compute_grouped_gemm_output_size(const Tensor& mat_a, const Tensor& mat_b, - const std::optional& offs, - std::optional out_dtype + const std::optional& offs ) { - c10::SmallVector out_size; const bool a_is_2d = mat_a.dim() == 2; const bool b_is_2d = mat_b.dim() == 2; if (a_is_2d) { if (b_is_2d) { - out_size = {offs->size(0), mat_a.size(0), mat_b.size(1)}; + return {offs->size(0), mat_a.size(0), mat_b.size(1)}; } else { TORCH_CHECK(offs->size(0) == mat_b.size(0), "matrix batch sizes have to match"); - out_size = {mat_a.size(0), mat_b.size(-1)}; + return {mat_a.size(0), mat_b.size(-1)}; } } else { if (b_is_2d) { // this case is not actually encountered for MoE gemms TORCH_CHECK(offs->size(0) == mat_a.size(0), "matrix batch sizes have to match"); - out_size = {mat_a.size(1), mat_b.size(1)}; + return {mat_a.size(1), mat_b.size(1)}; } else { // regular bmm TORCH_CHECK(mat_a.size(0) == mat_b.size(0), "batched dimension has to match"); - out_size = {mat_a.size(0), mat_a.size(1), mat_b.size(-1)}; + return {mat_a.size(0), mat_a.size(1), mat_b.size(-1)}; } } - - const auto out_dtype_ = out_dtype.value_or(kBFloat16); - TORCH_CHECK(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm"); - - // For TMA transfers, strides of output tensor have to be either - // 1, or aligned to 16 bytes. 
- const auto last_dim = out_size.size() - 1; - const auto alignment = 16 / c10::elementSize(out_dtype_); - const int64_t size_padded = (out_size[last_dim] + alignment - 1) / alignment * alignment; - std::vector out_stride; - if (a_is_2d != b_is_2d) { - out_stride = {size_padded, 1}; - } else { - out_stride = {out_size[1] * size_padded, size_padded, 1}; - } - auto out = at::empty_strided(out_size, out_stride, mat_a.options().dtype(out_dtype_)); - - return out; } bool check_valid_strides_and_return_transposed(const Tensor& mat) { @@ -1540,7 +1519,7 @@ namespace { TORCH_CHECK(tensor_strides[end_dim - 1] % alignment == 0, "strides should be multiple of 16 bytes"); return false; } else { - TORCH_CHECK(false, "Invalid strides/sizes, got ", mat.strides(), " for strides and ", mat.sizes(), " for sizes"); + TORCH_CHECK(false, "Tensor should have a contiguous dimension and not be self-overlapping, got ", mat.strides(), " for strides and ", mat.sizes(), " for sizes"); } } @@ -1648,7 +1627,11 @@ bool use_fast_accum) { check_scale(mat_a, scale_a, 0 ,0, scale_multiplier); check_scale(mat_b, scale_b, 1, 1, scale_multiplier); - Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype); + const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type()); + TORCH_CHECK(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm"); + const auto out_size = compute_grouped_gemm_output_size(mat_a, mat_b, offs); + Tensor out = at::empty(out_size, mat_a.options().dtype(out_dtype_)); + at::cuda::detail::f8f8bf16_grouped_mm( mat_a, @@ -1684,7 +1667,6 @@ std::optional out_dtype) { TORCH_CHECK(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d"); const bool a_is_2d = mat_a.dim() == 2; const bool b_is_2d = mat_b.dim() == 2; - // check that the strides are valid, the fn will throw an error if not check_valid_strides_and_return_transposed(mat_a); check_valid_strides_and_return_transposed(mat_b); @@ -1694,10 +1676,12 @@ std::optional out_dtype) { TORCH_CHECK(offs->dim() == 1, "offs has to be 1D"); TORCH_CHECK(offs->dtype() == at::kInt, "Offsets have to be int32"); } + const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type()); + TORCH_CHECK(out_dtype_ == kBFloat16, "Only bf16 high output type is supported for grouped gemm"); TORCH_CHECK(!bias.has_value(), "Bias not supported yet"); - Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype); - + const auto out_size = compute_grouped_gemm_output_size(mat_a, mat_b, offs); + Tensor out = at::empty(out_size, mat_a.options().dtype(out_dtype_)); at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out); return out; #else diff --git a/aten/src/ATen/native/cuda/GroupMMCommon.cuh b/aten/src/ATen/native/cuda/GroupMMCommon.cuh index a94bf9bcecf..a0474b7ad17 100644 --- a/aten/src/ATen/native/cuda/GroupMMCommon.cuh +++ b/aten/src/ATen/native/cuda/GroupMMCommon.cuh @@ -47,42 +47,10 @@ __global__ void prepare_grouped_gemm_data( if (offs != nullptr) { int32_t start = tid == 0 ? 0 : offs[tid - 1]; delta = offs[tid] - start; - if (K < 0) { - // CUTLASS cannot handle delta=0 here. - CUDA_KERNEL_ASSERT(delta >0 && "expected ofsets to be greater than 0\n"); - } else { - CUDA_KERNEL_ASSERT(delta >=0 && "expected ofsets to be greater or equal 0\n"); - } - - // TMA transfers require global memory tensor addresses to be - // aligned to 16 bytes. - if (tid < blockDim.x - 1) { - // Check this requirement for input tensors, in case group - // addresses are increased along the dynamic dimension. 
- if ((K < 0 && a_row_major) || // 2D/2D: check along K dimension - (M < 0 && !a_row_major)) { // 3D/2D: check along N dimension - int align = 128 / cutlass::sizeof_bits::value; - CUDA_KERNEL_ASSERT( - delta % align == 0 && - "expected input tensor dynamic dimension byte size to be non-negative multiple of 16\n"); - } - if ((K < 0 && !b_row_major) || // 2D/2D: check along K dimension - (N < 0 && b_row_major)) { // 3D/2D: check along N dimension - int align = 128 / cutlass::sizeof_bits::value; - CUDA_KERNEL_ASSERT( - delta % align == 0 && - "expected input tensor dynamic dimension byte size to be non-negative multiple of 16\n"); - } - - // Check the same requirement for output tensor (that is always - // contiguous, and in row-major layout). - if (N < 0) { - int align = 128 / cutlass::sizeof_bits::value; - CUDA_KERNEL_ASSERT( - delta % align == 0 && - "expected output tensor dynamic dimension byte size to be non-negative multiple of 16\n"); - } - } + int align = 16 / sizeof(DtypeA); + CUDA_KERNEL_ASSERT( + delta >=0 && delta % align == 0 && + "expected dynamic dimension byte size to be non-negative multiple of 16 \n"); } int64_t lda, ldb, ldoutput; if (M < 0) { @@ -113,6 +81,7 @@ __global__ void prepare_grouped_gemm_data( } else if (K < 0) { // A, B is 2d, output is 3d K = delta; + CUDA_KERNEL_ASSERT(delta > 0 && "can't handle K=0"); lda = a_row_major ? tensor_StrideA[0] : tensor_StrideA[1]; ldb = b_row_major ? tensor_StrideB[0] : tensor_StrideB[1]; ldoutput = tensor_StrideOutput[1]; diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 4e64c807425..163ebc4c447 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -315,7 +315,7 @@ class TestMatmulCuda(TestCase): def test_grouped_gemm_2d_2d(self, strided, a_row_major, b_row_major, use_torch_compile): device = "cuda" dtype = torch.bfloat16 - m, n, k, n_groups = 16, 32, 64, 4 + m, n, k, n_groups = 16, 32, 64, 4 # all sizes have to be divisible by 16 if a_row_major: a = torch.randn(m, k * n_groups + k * int(strided), device=device, dtype=dtype)[:, :k * n_groups] else: @@ -382,9 +382,6 @@ class TestMatmulCuda(TestCase): b_contig = b if b_row_major else b.transpose(-2, -1) self.assertTrue(b_contig.is_contiguous() is not strided) for check_zero_size in (False, True): - if check_zero_size and n_groups <= 1: - continue - a.grad = None b.grad = None offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32) @@ -487,9 +484,6 @@ class TestMatmulCuda(TestCase): b_contig = b if b_row_major else b.transpose(-2, -1) self.assertTrue(b_contig.is_contiguous() is not strided) for check_zero_size in (False, True): - if check_zero_size and n_groups <= 1: - continue - offs = torch.arange(n, n_groups * n + 1, n, device="cuda", dtype=torch.int32) if check_zero_size: offs[0] = offs[1] @@ -1651,27 +1645,17 @@ class TestFP8Matmul(TestCase): for a, b, ascale, bscale, out in zip(alist, blist, ascalelist, bscalelist, outlist): out_ref = torch._scaled_mm(a, b.t(), ascale.view(-1, 1), bscale.view(1, -1), out_dtype=torch.bfloat16, use_fast_accum=use_fast_accum) - self.assertEqual(out, out_ref, atol=5e-2, rtol=5e-4) - - # Testing only _scaled_grouped_mm() with multiple shapes, as - # _scaled_mm() already has more combinations of parameters than - # _scaled_grouped_mm(), for supporing more than one inputs layout - # combinations. 
+ self.assertEqual(out, out_ref, atol=8e-2, rtol=8e-4) @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") @xfailIfSM100OrLater @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") - @parametrize( - "n_groups, m, n, k", - [(2, 1, 16, 16), - (4, 16, 16, 16)], - name_fn=lambda n_groups, m, n, k: f"{n_groups}_{m}_{n}_{k}", - ) @parametrize("fast_accum", [False, True]) @parametrize("strided", [False, True]) @parametrize("use_torch_compile", [False, True]) - def test_scaled_grouped_gemm_2d_2d(self, n_groups, m, n, k, fast_accum, strided, use_torch_compile): + def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided, use_torch_compile): device = "cuda" + m, n, k, n_groups = 16, 32, 64, 4 # all sizes have to be divisible by 16 a = torch.randn(m, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups] b = torch.randn(n, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups] scale_a = torch.rand(m * n_groups, device=device, dtype=torch.float32) @@ -1701,26 +1685,18 @@ class TestFP8Matmul(TestCase): @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") @xfailIfSM100OrLater @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") - @parametrize( - "n_groups, m, n, k", - [(2, 1, 16, 16), - (4, 16, 16, 16)], - name_fn=lambda n_groups, m, n, k: f"{n_groups}_{m}_{n}_{k}", - ) @parametrize("fast_accum", [False, True]) @parametrize("strided", [False, True]) @parametrize("use_torch_compile", [False, True]) - def test_scaled_grouped_gemm_2d_3d(self, n_groups, m, n, k, fast_accum, strided, use_torch_compile): + def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, use_torch_compile): device = "cuda" s_int = int(strided) + m, n, k, n_groups = 16, 32, 64, 4 a = torch.randn(m * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k] b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k] self.assertTrue(a.is_contiguous() is not strided) self.assertTrue(b.is_contiguous() is not strided) for check_zero_size in (True, False): - if check_zero_size and n_groups <= 1: - continue - offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32) if check_zero_size: offs[0] = offs[1] @@ -1751,18 +1727,13 @@ class TestFP8Matmul(TestCase): @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") @xfailIfSM100OrLater @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") - @parametrize( - "n_groups, m, n, k", - [(2, 1, 16, 16), - (4, 16, 16, 16)], - name_fn=lambda n_groups, m, n, k: f"{n_groups}_{m}_{n}_{k}", - ) @parametrize("fast_accum", [False, True]) @parametrize("strided", [False, True]) @parametrize("use_torch_compile", [False, True]) - def test_scaled_grouped_gemm_3d_3d(self, n_groups, m, n, k, fast_accum, strided, use_torch_compile): + def test_scaled_grouped_gemm_3d_3d(self, fast_accum, strided, use_torch_compile): device = "cuda" s_int = int(strided) + m, n, k, n_groups = 16, 32, 64, 4 a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k] b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k] self.assertTrue(a.is_contiguous() is not strided) @@ -1786,18 +1757,13 @@ class TestFP8Matmul(TestCase): @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") @xfailIfSM100OrLater @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on 
SM90") - @parametrize( - "n_groups, m, n, k", - [(2, 1, 16, 16), - (4, 16, 16, 16)], - name_fn=lambda n_groups, m, n, k: f"{n_groups}_{m}_{n}_{k}", - ) @parametrize("fast_accum", [False, True]) @parametrize("strided", [False, True]) @parametrize("use_torch_compile", [False, True]) - def test_scaled_grouped_gemm_3d_2d(self, n_groups, m, n, k, fast_accum, strided, use_torch_compile): + def test_scaled_grouped_gemm_3d_2d(self, fast_accum, strided, use_torch_compile): device = "cuda" s_int = int(strided) + m, n, k, n_groups = 16, 32, 64, 4 a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k] b = torch.randn(n * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k] self.assertTrue(a.is_contiguous() is not strided) @@ -1805,9 +1771,6 @@ class TestFP8Matmul(TestCase): scale_a = torch.rand(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m) scale_b = torch.rand(n_groups * n, device="cuda", dtype=torch.float32) for check_zero_size in (True, False): - if check_zero_size and n_groups <= 1: - continue - offs = torch.arange(n, n_groups * n + 1, n, device="cuda", dtype=torch.int32) if check_zero_size: offs[0] = offs[1] diff --git a/torch/_inductor/kernel/mm_scaled_grouped.py b/torch/_inductor/kernel/mm_scaled_grouped.py index e83bd8c4b56..7e6e937d12d 100644 --- a/torch/_inductor/kernel/mm_scaled_grouped.py +++ b/torch/_inductor/kernel/mm_scaled_grouped.py @@ -120,6 +120,7 @@ def early_config_prune(g, m, configs, named_args): return pruned_configs +# Copied from fbgemm grouped_gemm.py triton_grouped_mm_source = r""" {%- if SCALED %} {%- if A_IS_2D or B_IS_2D %} @@ -670,7 +671,7 @@ def _tuned_grouped_mm_common( ) -@register_lowering(aten._grouped_mm.default, type_promotion_kind=None) +@register_lowering(aten._grouped_mm, type_promotion_kind=None) def tuned_grouped_mm( mat_a: TensorBox, mat_b: TensorBox, @@ -682,7 +683,7 @@ def tuned_grouped_mm( """Auto-tuning for _grouped_mm() operator.""" return _tuned_grouped_mm_common( - "aten._grouped_mm.default", + "aten._grouped_mm", "grouped_mm", aten__grouped_mm, triton_grouped_mm_template, diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 997d1fabad8..7a782ea6f3a 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -7467,39 +7467,28 @@ def sigmoid(self: Tensor) -> Tensor: return torch.empty_like(self, dtype=result_dtype) -def _create_grouped_mm_output_tensor(mat1, mat2, offs, out_dtype): +def _compute_grouped_mm_output_size(mat1, mat2, offs): mat1_is_2d = mat1.dim() == 2 mat2_is_2d = mat2.dim() == 2 if mat1_is_2d: if mat2_is_2d: - out_size = [offs.size(0), mat1.size(0), mat2.size(1)] + return offs.size(0), mat1.size(0), mat2.size(1) else: torch._check( offs.size(0) == mat2.size(0), "matrix batch sizes have to match" ) - out_size = [mat1.size(0), mat2.size(-1)] + return mat1.size(0), mat2.size(-1) else: if mat2_is_2d: torch._check( offs.size(0) == mat1.size(0), "matrix batch sizes have to match" ) - out_size = [mat1.size(1), mat2.size(1)] + return mat1.size(1), mat2.size(1) else: # regular bmm torch._check(mat1.size(0) == mat2.size(0), "batched dimension has to match") - out_size = [mat1.size(0), mat1.size(1), mat2.size(-1)] - - out_dtype = out_dtype or mat1.dtype - - alignment = 16 // out_dtype.itemsize - size_padded = (out_size[-1] + alignment - 1) // alignment * alignment - if mat1_is_2d == mat2_is_2d: - out_stride = [out_size[1] * size_padded, size_padded, 1] - else: - out_stride = [size_padded, 1] - out 
= torch.empty_strided(out_size, out_stride, dtype=out_dtype, device=mat1.device) - return out + return mat1.size(0), mat1.size(1), mat2.size(-1) def _meta_grouped_mm_common( @@ -7542,6 +7531,15 @@ def _meta_grouped_mm_common( mat_a_is_2d = mat_a.dim() == 2 mat_b_is_2d = mat_b.dim() == 2 + torch._check( + mat_a.shape[-1] % 16 == 0, + lambda: f"Expected mat_a.shape[-1] to be divisible by 16, but got mat_a.shape[-1]={mat_a.shape[1]}", + ) + torch._check( + mat_b.shape[-2] % 16 == 0 and mat_b.shape[-1] % 16 == 0, + lambda: f"Expected mat_b.shape[-2] and mat_b.shape[-1] to be both divisble by 16 but got {mat_b.shape[-2]} and {mat_b.shape[-1]}", # noqa: B950 + ) + if scaled: def is_row_major(mat): @@ -7563,7 +7561,7 @@ def _meta_grouped_mm_common( def check_valid_strides(mat_name, mat): end_dim = mat.dim() - 1 - alignment = 16 // mat.element_size() + alignment = 16 / mat.element_size() mat_stride = mat.stride() if mat_stride[end_dim - 1] == 1 and mat_stride[end_dim] >= max( 1, mat.shape[end_dim - 1] @@ -7582,7 +7580,7 @@ def _meta_grouped_mm_common( else: torch._check( False, - lambda: f"Invalid strides/sizes, got {mat_stride} for strides and {mat.shape} for sizes.", # noqa: B950 + lambda: f"Expected {mat_name} to have a contiguous dimension and not be mat_a-overlapping, got {mat_stride} for strides and {mat.shape} for sizes.", # noqa: B950 ) check_valid_strides("mat_a", mat_a) @@ -7667,7 +7665,9 @@ def _meta_grouped_mm_common( lambda: "If output dtype provided, it must be torch.bfloat16.", ) - return _create_grouped_mm_output_tensor(mat_a, mat_b, offs, out_dtype) + out_size = _compute_grouped_mm_output_size(mat_a, mat_b, offs) + out_dtype = out_dtype or mat_a.dtype + return torch.empty(out_size, dtype=out_dtype, device=mat_a.device) @register_meta(aten._grouped_mm)
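
For reference, a minimal Python sketch (not part of the diff) of the output allocation this revert removes: the innermost output dimension was padded so each row starts on a 16-byte boundary for TMA transfers, mirroring the deleted create_grouped_gemm_output_tensor / _create_grouped_mm_output_tensor helpers above. The function name, the bf16 default, and the CPU device below are illustrative assumptions, not part of the patch.

    import torch

    def padded_grouped_mm_output(out_size, out_dtype=torch.bfloat16, device="cpu"):
        # Pad the innermost dimension to a 16-byte boundary, then allocate with
        # explicit strides (as the removed helpers did via at::empty_strided).
        alignment = 16 // out_dtype.itemsize                       # elements per 16 bytes
        size_padded = (out_size[-1] + alignment - 1) // alignment * alignment
        if len(out_size) == 2:                                     # 2d output (2d/3d or 3d/2d inputs)
            out_stride = [size_padded, 1]
        else:                                                      # 3d output (2d/2d or 3d/3d inputs)
            out_stride = [out_size[1] * size_padded, size_padded, 1]
        return torch.empty_strided(out_size, out_stride, dtype=out_dtype, device=device)

    # e.g. a (4, 16, 30) bf16 output is padded from 30 to 32 elements per row:
    out = padded_grouped_mm_output([4, 16, 30])
    assert out.stride() == (512, 32, 1)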
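
Likewise, a small sketch of the per-group offset check that the restored prepare_grouped_gemm_data kernel performs on the dynamic dimension: each group's size (the delta between consecutive offsets) must be non-negative and must span a whole number of 16-byte chunks. The helper name and the bf16 example values are assumptions for illustration only.

    import torch

    def check_group_offsets(offs: torch.Tensor, elem_size: int) -> None:
        # Group sizes along the dynamic dimension are the deltas between
        # consecutive offsets (the first delta is measured from 0).
        zero = torch.zeros(1, dtype=offs.dtype, device=offs.device)
        deltas = torch.diff(offs, prepend=zero)
        align = 16 // elem_size                                    # elements per 16 bytes
        assert bool((deltas >= 0).all()), "offsets must be non-decreasing"
        assert bool((deltas % align == 0).all()), \
            "each group's dynamic dimension must be a multiple of 16 bytes"

    # bf16 elements are 2 bytes, so group sizes must be multiples of 8 elements:
    check_group_offsets(torch.tensor([16, 32, 64], dtype=torch.int32), elem_size=2)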