Revert "Refine alignment check along dynamic dimension for grouped MMs (#155466)"

This reverts commit 830a335a7d.

Reverted https://github.com/pytorch/pytorch/pull/155466 on behalf of https://github.com/atalman due to breaks internal builds ([comment](https://github.com/pytorch/pytorch/pull/155466#issuecomment-2988285117))
PyTorch MergeBot 2025-06-19 14:25:38 +00:00
parent fec8af8b98
commit 0b62465b99
5 changed files with 53 additions and 136 deletions

View File

@@ -36,7 +36,6 @@
#include <ATen/ops/copy_native.h>
#include <ATen/ops/dot_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/gelu.h>
#include <ATen/ops/max.h>
#include <ATen/ops/mm_native.h>
@@ -1482,49 +1481,29 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
}
namespace {
at::Tensor create_grouped_gemm_output_tensor(const Tensor& mat_a,
c10::SmallVector<int64_t, 3> compute_grouped_gemm_output_size(const Tensor& mat_a,
const Tensor& mat_b,
const std::optional<at::Tensor>& offs,
std::optional<c10::ScalarType> out_dtype
const std::optional<at::Tensor>& offs
) {
c10::SmallVector<int64_t, 3> out_size;
const bool a_is_2d = mat_a.dim() == 2;
const bool b_is_2d = mat_b.dim() == 2;
if (a_is_2d) {
if (b_is_2d) {
out_size = {offs->size(0), mat_a.size(0), mat_b.size(1)};
return {offs->size(0), mat_a.size(0), mat_b.size(1)};
} else {
TORCH_CHECK(offs->size(0) == mat_b.size(0), "matrix batch sizes have to match");
out_size = {mat_a.size(0), mat_b.size(-1)};
return {mat_a.size(0), mat_b.size(-1)};
}
} else {
if (b_is_2d) {
// this case is not actually encountered for MoE gemms
TORCH_CHECK(offs->size(0) == mat_a.size(0), "matrix batch sizes have to match");
out_size = {mat_a.size(1), mat_b.size(1)};
return {mat_a.size(1), mat_b.size(1)};
} else { // regular bmm
TORCH_CHECK(mat_a.size(0) == mat_b.size(0), "batched dimension has to match");
out_size = {mat_a.size(0), mat_a.size(1), mat_b.size(-1)};
return {mat_a.size(0), mat_a.size(1), mat_b.size(-1)};
}
}
const auto out_dtype_ = out_dtype.value_or(kBFloat16);
TORCH_CHECK(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
// For TMA transfers, strides of output tensor have to be either
// 1, or aligned to 16 bytes.
const auto last_dim = out_size.size() - 1;
const auto alignment = 16 / c10::elementSize(out_dtype_);
const int64_t size_padded = (out_size[last_dim] + alignment - 1) / alignment * alignment;
std::vector<int64_t> out_stride;
if (a_is_2d != b_is_2d) {
out_stride = {size_padded, 1};
} else {
out_stride = {out_size[1] * size_padded, size_padded, 1};
}
auto out = at::empty_strided(out_size, out_stride, mat_a.options().dtype(out_dtype_));
return out;
}
bool check_valid_strides_and_return_transposed(const Tensor& mat) {
@@ -1540,7 +1519,7 @@ namespace {
TORCH_CHECK(tensor_strides[end_dim - 1] % alignment == 0, "strides should be multiple of 16 bytes");
return false;
} else {
TORCH_CHECK(false, "Invalid strides/sizes, got ", mat.strides(), " for strides and ", mat.sizes(), " for sizes");
TORCH_CHECK(false, "Tensor should have a contiguous dimension and not be self-overlapping, got ", mat.strides(), " for strides and ", mat.sizes(), " for sizes");
}
}
@@ -1648,7 +1627,11 @@ bool use_fast_accum) {
check_scale(mat_a, scale_a, 0 ,0, scale_multiplier);
check_scale(mat_b, scale_b, 1, 1, scale_multiplier);
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype);
const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
TORCH_CHECK(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
const auto out_size = compute_grouped_gemm_output_size(mat_a, mat_b, offs);
Tensor out = at::empty(out_size, mat_a.options().dtype(out_dtype_));
at::cuda::detail::f8f8bf16_grouped_mm(
mat_a,
@@ -1684,7 +1667,6 @@ std::optional<c10::ScalarType> out_dtype) {
TORCH_CHECK(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
const bool a_is_2d = mat_a.dim() == 2;
const bool b_is_2d = mat_b.dim() == 2;
// check that the strides are valid, the fn will throw an error if not
check_valid_strides_and_return_transposed(mat_a);
check_valid_strides_and_return_transposed(mat_b);
@@ -1694,10 +1676,12 @@ std::optional<c10::ScalarType> out_dtype) {
TORCH_CHECK(offs->dim() == 1, "offs has to be 1D");
TORCH_CHECK(offs->dtype() == at::kInt, "Offsets have to be int32");
}
const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
TORCH_CHECK(out_dtype_ == kBFloat16, "Only bf16 high output type is supported for grouped gemm");
TORCH_CHECK(!bias.has_value(), "Bias not supported yet");
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype);
const auto out_size = compute_grouped_gemm_output_size(mat_a, mat_b, offs);
Tensor out = at::empty(out_size, mat_a.options().dtype(out_dtype_));
at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
return out;
#else

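Net effect of the hunks above: the deleted create_grouped_gemm_output_tensor allocated the grouped-mm output with a padded innermost stride so every row starts on a 16-byte boundary (a TMA requirement), while the restored path allocates a plain contiguous output and instead relies on the sizes themselves being 16-byte aligned. Below is a minimal Python sketch of the two allocation strategies; it is illustrative only, not the ATen code, and the example sizes are made up.

```python
import torch

def padded_stride_output(out_size, out_dtype=torch.bfloat16):
    # Sketch of the reverted behaviour: pad the innermost stride so every
    # row of the output starts on a 16-byte boundary, as TMA requires.
    alignment = 16 // torch.empty(0, dtype=out_dtype).element_size()
    size_padded = (out_size[-1] + alignment - 1) // alignment * alignment
    if len(out_size) == 2:
        out_stride = [size_padded, 1]
    else:
        out_stride = [out_size[1] * size_padded, size_padded, 1]
    return torch.empty_strided(out_size, out_stride, dtype=out_dtype)

def contiguous_output(out_size, out_dtype=torch.bfloat16):
    # Sketch of the restored behaviour: plain contiguous output; alignment
    # is instead guaranteed by requiring the sizes to be divisible by 16.
    return torch.empty(out_size, dtype=out_dtype)

# Hypothetical 2D output of 64 x 42 bf16 elements (42 is deliberately not
# a multiple of 8 elements == 16 bytes):
print(padded_stride_output([64, 42]).stride())  # (48, 1) -- last dim padded
print(contiguous_output([64, 42]).stride())     # (42, 1)
```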
View File

@@ -47,42 +47,10 @@ __global__ void prepare_grouped_gemm_data(
if (offs != nullptr) {
int32_t start = tid == 0 ? 0 : offs[tid - 1];
delta = offs[tid] - start;
if (K < 0) {
// CUTLASS cannot handle delta=0 here.
CUDA_KERNEL_ASSERT(delta > 0 && "expected offsets to be greater than 0\n");
} else {
CUDA_KERNEL_ASSERT(delta >= 0 && "expected offsets to be greater than or equal to 0\n");
}
// TMA transfers require global memory tensor addresses to be
// aligned to 16 bytes.
if (tid < blockDim.x - 1) {
// Check this requirement for input tensors, in case group
// addresses are increased along the dynamic dimension.
if ((K < 0 && a_row_major) || // 2D/2D: check along K dimension
(M < 0 && !a_row_major)) { // 3D/2D: check along N dimension
int align = 128 / cutlass::sizeof_bits<DtypeA>::value;
CUDA_KERNEL_ASSERT(
delta % align == 0 &&
"expected input tensor dynamic dimension byte size to be non-negative multiple of 16\n");
}
if ((K < 0 && !b_row_major) || // 2D/2D: check along K dimension
(N < 0 && b_row_major)) { // 3D/2D: check along N dimension
int align = 128 / cutlass::sizeof_bits<DtypeB>::value;
CUDA_KERNEL_ASSERT(
delta % align == 0 &&
"expected input tensor dynamic dimension byte size to be non-negative multiple of 16\n");
}
// Check the same requirement for output tensor (that is always
// contiguous, and in row-major layout).
if (N < 0) {
int align = 128 / cutlass::sizeof_bits<DtypeOutput>::value;
CUDA_KERNEL_ASSERT(
delta % align == 0 &&
"expected output tensor dynamic dimension byte size to be non-negative multiple of 16\n");
}
}
int align = 16 / sizeof(DtypeA);
CUDA_KERNEL_ASSERT(
delta >=0 && delta % align == 0 &&
"expected dynamic dimension byte size to be non-negative multiple of 16 \n");
}
int64_t lda, ldb, ldoutput;
if (M < 0) {
@@ -113,6 +81,7 @@ __global__ void prepare_grouped_gemm_data(
} else if (K < 0) {
// A, B is 2d, output is 3d
K = delta;
CUDA_KERNEL_ASSERT(delta > 0 && "can't handle K=0");
lda = a_row_major ? tensor_StrideA[0] : tensor_StrideA[1];
ldb = b_row_major ? tensor_StrideB[0] : tensor_StrideB[1];
ldoutput = tensor_StrideOutput[1];

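The kernel hunks above replace the reverted per-operand, per-layout TMA checks with a single assert on each group's extent along the dynamic dimension. The invariant can be stated host-side as in the sketch below; the real check runs on device via CUDA_KERNEL_ASSERT, and the bf16 element size used here is an assumption for illustration.

```python
import torch

def check_group_offsets(offs: torch.Tensor, elem_size: int = 2) -> None:
    # Host-side sketch of the invariant the restored device assert enforces:
    # each group's extent (delta) along the dynamic dimension must be
    # non-negative and a multiple of 16 bytes.
    align = 16 // elem_size  # elements per 16 bytes (8 for bf16, 16 for fp8)
    prev = 0
    for i, off in enumerate(offs.tolist()):
        delta = off - prev
        assert delta >= 0 and delta % align == 0, (
            f"group {i}: delta={delta} is not a non-negative multiple of {align}"
        )
        prev = off

# Offsets for 4 groups of 16 rows each (bf16 element size assumed):
check_group_offsets(torch.tensor([16, 32, 48, 64], dtype=torch.int32))
```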
View File

@@ -315,7 +315,7 @@ class TestMatmulCuda(TestCase):
def test_grouped_gemm_2d_2d(self, strided, a_row_major, b_row_major, use_torch_compile):
device = "cuda"
dtype = torch.bfloat16
m, n, k, n_groups = 16, 32, 64, 4
m, n, k, n_groups = 16, 32, 64, 4 # all sizes have to be divisible by 16
if a_row_major:
a = torch.randn(m, k * n_groups + k * int(strided), device=device, dtype=dtype)[:, :k * n_groups]
else:
@@ -382,9 +382,6 @@ class TestMatmulCuda(TestCase):
b_contig = b if b_row_major else b.transpose(-2, -1)
self.assertTrue(b_contig.is_contiguous() is not strided)
for check_zero_size in (False, True):
if check_zero_size and n_groups <= 1:
continue
a.grad = None
b.grad = None
offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
@@ -487,9 +484,6 @@ class TestMatmulCuda(TestCase):
b_contig = b if b_row_major else b.transpose(-2, -1)
self.assertTrue(b_contig.is_contiguous() is not strided)
for check_zero_size in (False, True):
if check_zero_size and n_groups <= 1:
continue
offs = torch.arange(n, n_groups * n + 1, n, device="cuda", dtype=torch.int32)
if check_zero_size:
offs[0] = offs[1]
@@ -1651,27 +1645,17 @@ class TestFP8Matmul(TestCase):
for a, b, ascale, bscale, out in zip(alist, blist, ascalelist, bscalelist, outlist):
out_ref = torch._scaled_mm(a, b.t(), ascale.view(-1, 1), bscale.view(1, -1),
out_dtype=torch.bfloat16, use_fast_accum=use_fast_accum)
self.assertEqual(out, out_ref, atol=5e-2, rtol=5e-4)
# Testing only _scaled_grouped_mm() with multiple shapes, as
# _scaled_mm() already has more combinations of parameters than
# _scaled_grouped_mm(), for supporting more than one input layout
# combination.
self.assertEqual(out, out_ref, atol=8e-2, rtol=8e-4)
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@xfailIfSM100OrLater
@unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
@parametrize(
"n_groups, m, n, k",
[(2, 1, 16, 16),
(4, 16, 16, 16)],
name_fn=lambda n_groups, m, n, k: f"{n_groups}_{m}_{n}_{k}",
)
@parametrize("fast_accum", [False, True])
@parametrize("strided", [False, True])
@parametrize("use_torch_compile", [False, True])
def test_scaled_grouped_gemm_2d_2d(self, n_groups, m, n, k, fast_accum, strided, use_torch_compile):
def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided, use_torch_compile):
device = "cuda"
m, n, k, n_groups = 16, 32, 64, 4 # all sizes have to be divisible by 16
a = torch.randn(m, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups]
b = torch.randn(n, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups]
scale_a = torch.rand(m * n_groups, device=device, dtype=torch.float32)
@@ -1701,26 +1685,18 @@ class TestFP8Matmul(TestCase):
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@xfailIfSM100OrLater
@unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
@parametrize(
"n_groups, m, n, k",
[(2, 1, 16, 16),
(4, 16, 16, 16)],
name_fn=lambda n_groups, m, n, k: f"{n_groups}_{m}_{n}_{k}",
)
@parametrize("fast_accum", [False, True])
@parametrize("strided", [False, True])
@parametrize("use_torch_compile", [False, True])
def test_scaled_grouped_gemm_2d_3d(self, n_groups, m, n, k, fast_accum, strided, use_torch_compile):
def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, use_torch_compile):
device = "cuda"
s_int = int(strided)
m, n, k, n_groups = 16, 32, 64, 4
a = torch.randn(m * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k]
b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
self.assertTrue(a.is_contiguous() is not strided)
self.assertTrue(b.is_contiguous() is not strided)
for check_zero_size in (True, False):
if check_zero_size and n_groups <= 1:
continue
offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
if check_zero_size:
offs[0] = offs[1]
@@ -1751,18 +1727,13 @@ class TestFP8Matmul(TestCase):
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@xfailIfSM100OrLater
@unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
@parametrize(
"n_groups, m, n, k",
[(2, 1, 16, 16),
(4, 16, 16, 16)],
name_fn=lambda n_groups, m, n, k: f"{n_groups}_{m}_{n}_{k}",
)
@parametrize("fast_accum", [False, True])
@parametrize("strided", [False, True])
@parametrize("use_torch_compile", [False, True])
def test_scaled_grouped_gemm_3d_3d(self, n_groups, m, n, k, fast_accum, strided, use_torch_compile):
def test_scaled_grouped_gemm_3d_3d(self, fast_accum, strided, use_torch_compile):
device = "cuda"
s_int = int(strided)
m, n, k, n_groups = 16, 32, 64, 4
a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
self.assertTrue(a.is_contiguous() is not strided)
@@ -1786,18 +1757,13 @@ class TestFP8Matmul(TestCase):
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@xfailIfSM100OrLater
@unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
@parametrize(
"n_groups, m, n, k",
[(2, 1, 16, 16),
(4, 16, 16, 16)],
name_fn=lambda n_groups, m, n, k: f"{n_groups}_{m}_{n}_{k}",
)
@parametrize("fast_accum", [False, True])
@parametrize("strided", [False, True])
@parametrize("use_torch_compile", [False, True])
def test_scaled_grouped_gemm_3d_2d(self, n_groups, m, n, k, fast_accum, strided, use_torch_compile):
def test_scaled_grouped_gemm_3d_2d(self, fast_accum, strided, use_torch_compile):
device = "cuda"
s_int = int(strided)
m, n, k, n_groups = 16, 32, 64, 4
a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
b = torch.randn(n * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k]
self.assertTrue(a.is_contiguous() is not strided)
@@ -1805,9 +1771,6 @@ class TestFP8Matmul(TestCase):
scale_a = torch.rand(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m)
scale_b = torch.rand(n_groups * n, device="cuda", dtype=torch.float32)
for check_zero_size in (True, False):
if check_zero_size and n_groups <= 1:
continue
offs = torch.arange(n, n_groups * n + 1, n, device="cuda", dtype=torch.int32)
if check_zero_size:
offs[0] = offs[1]

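In the test hunks above, the parametrized (n_groups, m, n, k) shapes are dropped in favour of fixed sizes that are all divisible by 16, and the zero-size-group case is kept by collapsing the first offset onto the second. A small sketch of that offsets setup follows; it uses CPU tensors for illustration, whereas the tests build them on CUDA and then call the grouped-mm ops on an SM90 GPU.

```python
import torch

m, n, k, n_groups = 16, 32, 64, 4  # all divisible by 16, as the tests require

# Group boundaries along the dynamic dimension: one entry per group.
offs = torch.arange(m, n_groups * m + 1, m, dtype=torch.int32)

# Zero-size-group variant used by the tests: make the first group empty
# by collapsing its boundary onto the second one.
offs_zero = offs.clone()
offs_zero[0] = offs_zero[1]

print(offs.tolist())       # [16, 32, 48, 64]
print(offs_zero.tolist())  # [32, 32, 48, 64]

# The tests build these on CUDA and pass them as `offs=` to the grouped-mm
# ops together with bf16/fp8 inputs; running that part needs an SM90 GPU.
```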
View File

@@ -120,6 +120,7 @@ def early_config_prune(g, m, configs, named_args):
return pruned_configs
# Copied from fbgemm grouped_gemm.py
triton_grouped_mm_source = r"""
{%- if SCALED %}
{%- if A_IS_2D or B_IS_2D %}
@@ -670,7 +671,7 @@ def _tuned_grouped_mm_common(
)
@register_lowering(aten._grouped_mm.default, type_promotion_kind=None)
@register_lowering(aten._grouped_mm, type_promotion_kind=None)
def tuned_grouped_mm(
mat_a: TensorBox,
mat_b: TensorBox,
@@ -682,7 +683,7 @@ def tuned_grouped_mm(
"""Auto-tuning for _grouped_mm() operator."""
return _tuned_grouped_mm_common(
"aten._grouped_mm.default",
"aten._grouped_mm",
"grouped_mm",
aten__grouped_mm,
triton_grouped_mm_template,

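The Inductor hunk above re-keys the lowering on the aten._grouped_mm overload packet rather than the explicit .default overload. On a build where the operator exists, both names resolve to the same underlying op; the snippet below is a quick check of that, and it assumes a recent PyTorch that registers aten._grouped_mm.

```python
import torch

aten = torch.ops.aten

# Packet vs. explicit overload of the same operator (requires a PyTorch
# build that defines aten._grouped_mm).
print(type(aten._grouped_mm))          # torch._ops.OpOverloadPacket
print(type(aten._grouped_mm.default))  # torch._ops.OpOverload
print(aten._grouped_mm.default)        # e.g. aten._grouped_mm.default
```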
View File

@@ -7467,39 +7467,28 @@ def sigmoid(self: Tensor) -> Tensor:
return torch.empty_like(self, dtype=result_dtype)
def _create_grouped_mm_output_tensor(mat1, mat2, offs, out_dtype):
def _compute_grouped_mm_output_size(mat1, mat2, offs):
mat1_is_2d = mat1.dim() == 2
mat2_is_2d = mat2.dim() == 2
if mat1_is_2d:
if mat2_is_2d:
out_size = [offs.size(0), mat1.size(0), mat2.size(1)]
return offs.size(0), mat1.size(0), mat2.size(1)
else:
torch._check(
offs.size(0) == mat2.size(0), "matrix batch sizes have to match"
)
out_size = [mat1.size(0), mat2.size(-1)]
return mat1.size(0), mat2.size(-1)
else:
if mat2_is_2d:
torch._check(
offs.size(0) == mat1.size(0), "matrix batch sizes have to match"
)
out_size = [mat1.size(1), mat2.size(1)]
return mat1.size(1), mat2.size(1)
else:
# regular bmm
torch._check(mat1.size(0) == mat2.size(0), "batched dimension has to match")
out_size = [mat1.size(0), mat1.size(1), mat2.size(-1)]
out_dtype = out_dtype or mat1.dtype
alignment = 16 // out_dtype.itemsize
size_padded = (out_size[-1] + alignment - 1) // alignment * alignment
if mat1_is_2d == mat2_is_2d:
out_stride = [out_size[1] * size_padded, size_padded, 1]
else:
out_stride = [size_padded, 1]
out = torch.empty_strided(out_size, out_stride, dtype=out_dtype, device=mat1.device)
return out
return mat1.size(0), mat1.size(1), mat2.size(-1)
def _meta_grouped_mm_common(
@@ -7542,6 +7531,15 @@ def _meta_grouped_mm_common(
mat_a_is_2d = mat_a.dim() == 2
mat_b_is_2d = mat_b.dim() == 2
torch._check(
mat_a.shape[-1] % 16 == 0,
lambda: f"Expected mat_a.shape[-1] to be divisible by 16, but got mat_a.shape[-1]={mat_a.shape[1]}",
)
torch._check(
mat_b.shape[-2] % 16 == 0 and mat_b.shape[-1] % 16 == 0,
lambda: f"Expected mat_b.shape[-2] and mat_b.shape[-1] to be both divisble by 16 but got {mat_b.shape[-2]} and {mat_b.shape[-1]}", # noqa: B950
)
if scaled:
def is_row_major(mat):
@@ -7563,7 +7561,7 @@ def _meta_grouped_mm_common(
def check_valid_strides(mat_name, mat):
end_dim = mat.dim() - 1
alignment = 16 // mat.element_size()
alignment = 16 / mat.element_size()
mat_stride = mat.stride()
if mat_stride[end_dim - 1] == 1 and mat_stride[end_dim] >= max(
1, mat.shape[end_dim - 1]
@@ -7582,7 +7580,7 @@ def _meta_grouped_mm_common(
else:
torch._check(
False,
lambda: f"Invalid strides/sizes, got {mat_stride} for strides and {mat.shape} for sizes.", # noqa: B950
lambda: f"Expected {mat_name} to have a contiguous dimension and not be mat_a-overlapping, got {mat_stride} for strides and {mat.shape} for sizes.", # noqa: B950
)
check_valid_strides("mat_a", mat_a)
@@ -7667,7 +7665,9 @@ def _meta_grouped_mm_common(
lambda: "If output dtype provided, it must be torch.bfloat16.",
)
return _create_grouped_mm_output_tensor(mat_a, mat_b, offs, out_dtype)
out_size = _compute_grouped_mm_output_size(mat_a, mat_b, offs)
out_dtype = out_dtype or mat_a.dtype
return torch.empty(out_size, dtype=out_dtype, device=mat_a.device)
@register_meta(aten._grouped_mm)
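For quick reference, the restored _compute_grouped_mm_output_size above reduces to four cases depending on whether each operand is 2D or 3D. The standalone sketch below mirrors that rule with hypothetical shapes; the batch-size checks are omitted.

```python
def grouped_mm_out_size(a_shape, b_shape, n_groups):
    # Mirrors the restored _compute_grouped_mm_output_size (checks omitted).
    a_2d, b_2d = len(a_shape) == 2, len(b_shape) == 2
    if a_2d and b_2d:          # 2D/2D: one output slice per group
        return (n_groups, a_shape[0], b_shape[1])
    if a_2d and not b_2d:      # 2D/3D: rows of mat_a are split by offs
        return (a_shape[0], b_shape[-1])
    if not a_2d and b_2d:      # 3D/2D: columns of mat_b are split by offs
        return (a_shape[1], b_shape[1])
    return (a_shape[0], a_shape[1], b_shape[-1])  # 3D/3D: regular bmm

# Hypothetical shapes matching the test sizes m=16, n=32, k=64, n_groups=4:
print(grouped_mm_out_size((16, 256), (256, 32), 4))      # (4, 16, 32)
print(grouped_mm_out_size((64, 64), (4, 64, 32), 4))     # (64, 32)
print(grouped_mm_out_size((4, 16, 64), (64, 128), 4))    # (16, 128)
print(grouped_mm_out_size((4, 16, 64), (4, 64, 32), 4))  # (4, 16, 32)
```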