Revert "Refine alignment check along dynamic dimension for grouped MMs (#155466)"

This reverts commit 830a335a7d.

Reverted https://github.com/pytorch/pytorch/pull/155466 on behalf of https://github.com/atalman due to breaks internal builds ([comment](https://github.com/pytorch/pytorch/pull/155466#issuecomment-2988285117))
This commit is contained in:
PyTorch MergeBot 2025-06-19 14:25:38 +00:00
parent fec8af8b98
commit 0b62465b99
5 changed files with 53 additions and 136 deletions

View File

@ -36,7 +36,6 @@
#include <ATen/ops/copy_native.h> #include <ATen/ops/copy_native.h>
#include <ATen/ops/dot_native.h> #include <ATen/ops/dot_native.h>
#include <ATen/ops/empty.h> #include <ATen/ops/empty.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/gelu.h> #include <ATen/ops/gelu.h>
#include <ATen/ops/max.h> #include <ATen/ops/max.h>
#include <ATen/ops/mm_native.h> #include <ATen/ops/mm_native.h>
@ -1482,49 +1481,29 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
} }
namespace { namespace {
at::Tensor create_grouped_gemm_output_tensor(const Tensor& mat_a, c10::SmallVector<int64_t, 3> compute_grouped_gemm_output_size(const Tensor& mat_a,
const Tensor& mat_b, const Tensor& mat_b,
const std::optional<at::Tensor>& offs, const std::optional<at::Tensor>& offs
std::optional<c10::ScalarType> out_dtype
) { ) {
c10::SmallVector<int64_t, 3> out_size;
const bool a_is_2d = mat_a.dim() == 2; const bool a_is_2d = mat_a.dim() == 2;
const bool b_is_2d = mat_b.dim() == 2; const bool b_is_2d = mat_b.dim() == 2;
if (a_is_2d) { if (a_is_2d) {
if (b_is_2d) { if (b_is_2d) {
out_size = {offs->size(0), mat_a.size(0), mat_b.size(1)}; return {offs->size(0), mat_a.size(0), mat_b.size(1)};
} else { } else {
TORCH_CHECK(offs->size(0) == mat_b.size(0), "matrix batch sizes have to match"); TORCH_CHECK(offs->size(0) == mat_b.size(0), "matrix batch sizes have to match");
out_size = {mat_a.size(0), mat_b.size(-1)}; return {mat_a.size(0), mat_b.size(-1)};
} }
} else { } else {
if (b_is_2d) { if (b_is_2d) {
// this case is not actually encountered for MoE gemms // this case is not actually encountered for MoE gemms
TORCH_CHECK(offs->size(0) == mat_a.size(0), "matrix batch sizes have to match"); TORCH_CHECK(offs->size(0) == mat_a.size(0), "matrix batch sizes have to match");
out_size = {mat_a.size(1), mat_b.size(1)}; return {mat_a.size(1), mat_b.size(1)};
} else { // regular bmm } else { // regular bmm
TORCH_CHECK(mat_a.size(0) == mat_b.size(0), "batched dimension has to match"); TORCH_CHECK(mat_a.size(0) == mat_b.size(0), "batched dimension has to match");
out_size = {mat_a.size(0), mat_a.size(1), mat_b.size(-1)}; return {mat_a.size(0), mat_a.size(1), mat_b.size(-1)};
} }
} }
const auto out_dtype_ = out_dtype.value_or(kBFloat16);
TORCH_CHECK(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
// For TMA transfers, strides of output tensor have to be either
// 1, or aligned to 16 bytes.
const auto last_dim = out_size.size() - 1;
const auto alignment = 16 / c10::elementSize(out_dtype_);
const int64_t size_padded = (out_size[last_dim] + alignment - 1) / alignment * alignment;
std::vector<int64_t> out_stride;
if (a_is_2d != b_is_2d) {
out_stride = {size_padded, 1};
} else {
out_stride = {out_size[1] * size_padded, size_padded, 1};
}
auto out = at::empty_strided(out_size, out_stride, mat_a.options().dtype(out_dtype_));
return out;
} }
bool check_valid_strides_and_return_transposed(const Tensor& mat) { bool check_valid_strides_and_return_transposed(const Tensor& mat) {
@ -1540,7 +1519,7 @@ namespace {
TORCH_CHECK(tensor_strides[end_dim - 1] % alignment == 0, "strides should be multiple of 16 bytes"); TORCH_CHECK(tensor_strides[end_dim - 1] % alignment == 0, "strides should be multiple of 16 bytes");
return false; return false;
} else { } else {
TORCH_CHECK(false, "Invalid strides/sizes, got ", mat.strides(), " for strides and ", mat.sizes(), " for sizes"); TORCH_CHECK(false, "Tensor should have a contiguous dimension and not be self-overlapping, got ", mat.strides(), " for strides and ", mat.sizes(), " for sizes");
} }
} }
@ -1648,7 +1627,11 @@ bool use_fast_accum) {
check_scale(mat_a, scale_a, 0 ,0, scale_multiplier); check_scale(mat_a, scale_a, 0 ,0, scale_multiplier);
check_scale(mat_b, scale_b, 1, 1, scale_multiplier); check_scale(mat_b, scale_b, 1, 1, scale_multiplier);
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype); const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
TORCH_CHECK(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
const auto out_size = compute_grouped_gemm_output_size(mat_a, mat_b, offs);
Tensor out = at::empty(out_size, mat_a.options().dtype(out_dtype_));
at::cuda::detail::f8f8bf16_grouped_mm( at::cuda::detail::f8f8bf16_grouped_mm(
mat_a, mat_a,
@ -1684,7 +1667,6 @@ std::optional<c10::ScalarType> out_dtype) {
TORCH_CHECK(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d"); TORCH_CHECK(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
const bool a_is_2d = mat_a.dim() == 2; const bool a_is_2d = mat_a.dim() == 2;
const bool b_is_2d = mat_b.dim() == 2; const bool b_is_2d = mat_b.dim() == 2;
// check that the strides are valid, the fn will throw an error if not // check that the strides are valid, the fn will throw an error if not
check_valid_strides_and_return_transposed(mat_a); check_valid_strides_and_return_transposed(mat_a);
check_valid_strides_and_return_transposed(mat_b); check_valid_strides_and_return_transposed(mat_b);
@ -1694,10 +1676,12 @@ std::optional<c10::ScalarType> out_dtype) {
TORCH_CHECK(offs->dim() == 1, "offs has to be 1D"); TORCH_CHECK(offs->dim() == 1, "offs has to be 1D");
TORCH_CHECK(offs->dtype() == at::kInt, "Offsets have to be int32"); TORCH_CHECK(offs->dtype() == at::kInt, "Offsets have to be int32");
} }
const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
TORCH_CHECK(out_dtype_ == kBFloat16, "Only bf16 high output type is supported for grouped gemm");
TORCH_CHECK(!bias.has_value(), "Bias not supported yet"); TORCH_CHECK(!bias.has_value(), "Bias not supported yet");
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype); const auto out_size = compute_grouped_gemm_output_size(mat_a, mat_b, offs);
Tensor out = at::empty(out_size, mat_a.options().dtype(out_dtype_));
at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out); at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
return out; return out;
#else #else

View File

@ -47,42 +47,10 @@ __global__ void prepare_grouped_gemm_data(
if (offs != nullptr) { if (offs != nullptr) {
int32_t start = tid == 0 ? 0 : offs[tid - 1]; int32_t start = tid == 0 ? 0 : offs[tid - 1];
delta = offs[tid] - start; delta = offs[tid] - start;
if (K < 0) { int align = 16 / sizeof(DtypeA);
// CUTLASS cannot handle delta=0 here. CUDA_KERNEL_ASSERT(
CUDA_KERNEL_ASSERT(delta >0 && "expected ofsets to be greater than 0\n"); delta >=0 && delta % align == 0 &&
} else { "expected dynamic dimension byte size to be non-negative multiple of 16 \n");
CUDA_KERNEL_ASSERT(delta >=0 && "expected ofsets to be greater or equal 0\n");
}
// TMA transfers require global memory tensor addresses to be
// aligned to 16 bytes.
if (tid < blockDim.x - 1) {
// Check this requirement for input tensors, in case group
// addresses are increased along the dynamic dimension.
if ((K < 0 && a_row_major) || // 2D/2D: check along K dimension
(M < 0 && !a_row_major)) { // 3D/2D: check along N dimension
int align = 128 / cutlass::sizeof_bits<DtypeA>::value;
CUDA_KERNEL_ASSERT(
delta % align == 0 &&
"expected input tensor dynamic dimension byte size to be non-negative multiple of 16\n");
}
if ((K < 0 && !b_row_major) || // 2D/2D: check along K dimension
(N < 0 && b_row_major)) { // 3D/2D: check along N dimension
int align = 128 / cutlass::sizeof_bits<DtypeB>::value;
CUDA_KERNEL_ASSERT(
delta % align == 0 &&
"expected input tensor dynamic dimension byte size to be non-negative multiple of 16\n");
}
// Check the same requirement for output tensor (that is always
// contiguous, and in row-major layout).
if (N < 0) {
int align = 128 / cutlass::sizeof_bits<DtypeOutput>::value;
CUDA_KERNEL_ASSERT(
delta % align == 0 &&
"expected output tensor dynamic dimension byte size to be non-negative multiple of 16\n");
}
}
} }
int64_t lda, ldb, ldoutput; int64_t lda, ldb, ldoutput;
if (M < 0) { if (M < 0) {
@ -113,6 +81,7 @@ __global__ void prepare_grouped_gemm_data(
} else if (K < 0) { } else if (K < 0) {
// A, B is 2d, output is 3d // A, B is 2d, output is 3d
K = delta; K = delta;
CUDA_KERNEL_ASSERT(delta > 0 && "can't handle K=0");
lda = a_row_major ? tensor_StrideA[0] : tensor_StrideA[1]; lda = a_row_major ? tensor_StrideA[0] : tensor_StrideA[1];
ldb = b_row_major ? tensor_StrideB[0] : tensor_StrideB[1]; ldb = b_row_major ? tensor_StrideB[0] : tensor_StrideB[1];
ldoutput = tensor_StrideOutput[1]; ldoutput = tensor_StrideOutput[1];

View File

@ -315,7 +315,7 @@ class TestMatmulCuda(TestCase):
def test_grouped_gemm_2d_2d(self, strided, a_row_major, b_row_major, use_torch_compile): def test_grouped_gemm_2d_2d(self, strided, a_row_major, b_row_major, use_torch_compile):
device = "cuda" device = "cuda"
dtype = torch.bfloat16 dtype = torch.bfloat16
m, n, k, n_groups = 16, 32, 64, 4 m, n, k, n_groups = 16, 32, 64, 4 # all sizes have to be divisible by 16
if a_row_major: if a_row_major:
a = torch.randn(m, k * n_groups + k * int(strided), device=device, dtype=dtype)[:, :k * n_groups] a = torch.randn(m, k * n_groups + k * int(strided), device=device, dtype=dtype)[:, :k * n_groups]
else: else:
@ -382,9 +382,6 @@ class TestMatmulCuda(TestCase):
b_contig = b if b_row_major else b.transpose(-2, -1) b_contig = b if b_row_major else b.transpose(-2, -1)
self.assertTrue(b_contig.is_contiguous() is not strided) self.assertTrue(b_contig.is_contiguous() is not strided)
for check_zero_size in (False, True): for check_zero_size in (False, True):
if check_zero_size and n_groups <= 1:
continue
a.grad = None a.grad = None
b.grad = None b.grad = None
offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32) offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
@ -487,9 +484,6 @@ class TestMatmulCuda(TestCase):
b_contig = b if b_row_major else b.transpose(-2, -1) b_contig = b if b_row_major else b.transpose(-2, -1)
self.assertTrue(b_contig.is_contiguous() is not strided) self.assertTrue(b_contig.is_contiguous() is not strided)
for check_zero_size in (False, True): for check_zero_size in (False, True):
if check_zero_size and n_groups <= 1:
continue
offs = torch.arange(n, n_groups * n + 1, n, device="cuda", dtype=torch.int32) offs = torch.arange(n, n_groups * n + 1, n, device="cuda", dtype=torch.int32)
if check_zero_size: if check_zero_size:
offs[0] = offs[1] offs[0] = offs[1]
@ -1651,27 +1645,17 @@ class TestFP8Matmul(TestCase):
for a, b, ascale, bscale, out in zip(alist, blist, ascalelist, bscalelist, outlist): for a, b, ascale, bscale, out in zip(alist, blist, ascalelist, bscalelist, outlist):
out_ref = torch._scaled_mm(a, b.t(), ascale.view(-1, 1), bscale.view(1, -1), out_ref = torch._scaled_mm(a, b.t(), ascale.view(-1, 1), bscale.view(1, -1),
out_dtype=torch.bfloat16, use_fast_accum=use_fast_accum) out_dtype=torch.bfloat16, use_fast_accum=use_fast_accum)
self.assertEqual(out, out_ref, atol=5e-2, rtol=5e-4) self.assertEqual(out, out_ref, atol=8e-2, rtol=8e-4)
# Testing only _scaled_grouped_mm() with multiple shapes, as
# _scaled_mm() already has more combinations of parameters than
# _scaled_grouped_mm(), for supporing more than one inputs layout
# combinations.
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@xfailIfSM100OrLater @xfailIfSM100OrLater
@unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
@parametrize(
"n_groups, m, n, k",
[(2, 1, 16, 16),
(4, 16, 16, 16)],
name_fn=lambda n_groups, m, n, k: f"{n_groups}_{m}_{n}_{k}",
)
@parametrize("fast_accum", [False, True]) @parametrize("fast_accum", [False, True])
@parametrize("strided", [False, True]) @parametrize("strided", [False, True])
@parametrize("use_torch_compile", [False, True]) @parametrize("use_torch_compile", [False, True])
def test_scaled_grouped_gemm_2d_2d(self, n_groups, m, n, k, fast_accum, strided, use_torch_compile): def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided, use_torch_compile):
device = "cuda" device = "cuda"
m, n, k, n_groups = 16, 32, 64, 4 # all sizes have to be divisible by 16
a = torch.randn(m, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups] a = torch.randn(m, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups]
b = torch.randn(n, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups] b = torch.randn(n, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups]
scale_a = torch.rand(m * n_groups, device=device, dtype=torch.float32) scale_a = torch.rand(m * n_groups, device=device, dtype=torch.float32)
@ -1701,26 +1685,18 @@ class TestFP8Matmul(TestCase):
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@xfailIfSM100OrLater @xfailIfSM100OrLater
@unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
@parametrize(
"n_groups, m, n, k",
[(2, 1, 16, 16),
(4, 16, 16, 16)],
name_fn=lambda n_groups, m, n, k: f"{n_groups}_{m}_{n}_{k}",
)
@parametrize("fast_accum", [False, True]) @parametrize("fast_accum", [False, True])
@parametrize("strided", [False, True]) @parametrize("strided", [False, True])
@parametrize("use_torch_compile", [False, True]) @parametrize("use_torch_compile", [False, True])
def test_scaled_grouped_gemm_2d_3d(self, n_groups, m, n, k, fast_accum, strided, use_torch_compile): def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, use_torch_compile):
device = "cuda" device = "cuda"
s_int = int(strided) s_int = int(strided)
m, n, k, n_groups = 16, 32, 64, 4
a = torch.randn(m * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k] a = torch.randn(m * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k]
b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k] b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
self.assertTrue(a.is_contiguous() is not strided) self.assertTrue(a.is_contiguous() is not strided)
self.assertTrue(b.is_contiguous() is not strided) self.assertTrue(b.is_contiguous() is not strided)
for check_zero_size in (True, False): for check_zero_size in (True, False):
if check_zero_size and n_groups <= 1:
continue
offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32) offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
if check_zero_size: if check_zero_size:
offs[0] = offs[1] offs[0] = offs[1]
@ -1751,18 +1727,13 @@ class TestFP8Matmul(TestCase):
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@xfailIfSM100OrLater @xfailIfSM100OrLater
@unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
@parametrize(
"n_groups, m, n, k",
[(2, 1, 16, 16),
(4, 16, 16, 16)],
name_fn=lambda n_groups, m, n, k: f"{n_groups}_{m}_{n}_{k}",
)
@parametrize("fast_accum", [False, True]) @parametrize("fast_accum", [False, True])
@parametrize("strided", [False, True]) @parametrize("strided", [False, True])
@parametrize("use_torch_compile", [False, True]) @parametrize("use_torch_compile", [False, True])
def test_scaled_grouped_gemm_3d_3d(self, n_groups, m, n, k, fast_accum, strided, use_torch_compile): def test_scaled_grouped_gemm_3d_3d(self, fast_accum, strided, use_torch_compile):
device = "cuda" device = "cuda"
s_int = int(strided) s_int = int(strided)
m, n, k, n_groups = 16, 32, 64, 4
a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k] a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k] b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
self.assertTrue(a.is_contiguous() is not strided) self.assertTrue(a.is_contiguous() is not strided)
@ -1786,18 +1757,13 @@ class TestFP8Matmul(TestCase):
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@xfailIfSM100OrLater @xfailIfSM100OrLater
@unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
@parametrize(
"n_groups, m, n, k",
[(2, 1, 16, 16),
(4, 16, 16, 16)],
name_fn=lambda n_groups, m, n, k: f"{n_groups}_{m}_{n}_{k}",
)
@parametrize("fast_accum", [False, True]) @parametrize("fast_accum", [False, True])
@parametrize("strided", [False, True]) @parametrize("strided", [False, True])
@parametrize("use_torch_compile", [False, True]) @parametrize("use_torch_compile", [False, True])
def test_scaled_grouped_gemm_3d_2d(self, n_groups, m, n, k, fast_accum, strided, use_torch_compile): def test_scaled_grouped_gemm_3d_2d(self, fast_accum, strided, use_torch_compile):
device = "cuda" device = "cuda"
s_int = int(strided) s_int = int(strided)
m, n, k, n_groups = 16, 32, 64, 4
a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k] a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
b = torch.randn(n * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k] b = torch.randn(n * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k]
self.assertTrue(a.is_contiguous() is not strided) self.assertTrue(a.is_contiguous() is not strided)
@ -1805,9 +1771,6 @@ class TestFP8Matmul(TestCase):
scale_a = torch.rand(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m) scale_a = torch.rand(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m)
scale_b = torch.rand(n_groups * n, device="cuda", dtype=torch.float32) scale_b = torch.rand(n_groups * n, device="cuda", dtype=torch.float32)
for check_zero_size in (True, False): for check_zero_size in (True, False):
if check_zero_size and n_groups <= 1:
continue
offs = torch.arange(n, n_groups * n + 1, n, device="cuda", dtype=torch.int32) offs = torch.arange(n, n_groups * n + 1, n, device="cuda", dtype=torch.int32)
if check_zero_size: if check_zero_size:
offs[0] = offs[1] offs[0] = offs[1]

View File

@ -120,6 +120,7 @@ def early_config_prune(g, m, configs, named_args):
return pruned_configs return pruned_configs
# Copied from fbgemm grouped_gemm.py
triton_grouped_mm_source = r""" triton_grouped_mm_source = r"""
{%- if SCALED %} {%- if SCALED %}
{%- if A_IS_2D or B_IS_2D %} {%- if A_IS_2D or B_IS_2D %}
@ -670,7 +671,7 @@ def _tuned_grouped_mm_common(
) )
@register_lowering(aten._grouped_mm.default, type_promotion_kind=None) @register_lowering(aten._grouped_mm, type_promotion_kind=None)
def tuned_grouped_mm( def tuned_grouped_mm(
mat_a: TensorBox, mat_a: TensorBox,
mat_b: TensorBox, mat_b: TensorBox,
@ -682,7 +683,7 @@ def tuned_grouped_mm(
"""Auto-tuning for _grouped_mm() operator.""" """Auto-tuning for _grouped_mm() operator."""
return _tuned_grouped_mm_common( return _tuned_grouped_mm_common(
"aten._grouped_mm.default", "aten._grouped_mm",
"grouped_mm", "grouped_mm",
aten__grouped_mm, aten__grouped_mm,
triton_grouped_mm_template, triton_grouped_mm_template,

View File

@ -7467,39 +7467,28 @@ def sigmoid(self: Tensor) -> Tensor:
return torch.empty_like(self, dtype=result_dtype) return torch.empty_like(self, dtype=result_dtype)
def _create_grouped_mm_output_tensor(mat1, mat2, offs, out_dtype): def _compute_grouped_mm_output_size(mat1, mat2, offs):
mat1_is_2d = mat1.dim() == 2 mat1_is_2d = mat1.dim() == 2
mat2_is_2d = mat2.dim() == 2 mat2_is_2d = mat2.dim() == 2
if mat1_is_2d: if mat1_is_2d:
if mat2_is_2d: if mat2_is_2d:
out_size = [offs.size(0), mat1.size(0), mat2.size(1)] return offs.size(0), mat1.size(0), mat2.size(1)
else: else:
torch._check( torch._check(
offs.size(0) == mat2.size(0), "matrix batch sizes have to match" offs.size(0) == mat2.size(0), "matrix batch sizes have to match"
) )
out_size = [mat1.size(0), mat2.size(-1)] return mat1.size(0), mat2.size(-1)
else: else:
if mat2_is_2d: if mat2_is_2d:
torch._check( torch._check(
offs.size(0) == mat1.size(0), "matrix batch sizes have to match" offs.size(0) == mat1.size(0), "matrix batch sizes have to match"
) )
out_size = [mat1.size(1), mat2.size(1)] return mat1.size(1), mat2.size(1)
else: else:
# regular bmm # regular bmm
torch._check(mat1.size(0) == mat2.size(0), "batched dimension has to match") torch._check(mat1.size(0) == mat2.size(0), "batched dimension has to match")
out_size = [mat1.size(0), mat1.size(1), mat2.size(-1)] return mat1.size(0), mat1.size(1), mat2.size(-1)
out_dtype = out_dtype or mat1.dtype
alignment = 16 // out_dtype.itemsize
size_padded = (out_size[-1] + alignment - 1) // alignment * alignment
if mat1_is_2d == mat2_is_2d:
out_stride = [out_size[1] * size_padded, size_padded, 1]
else:
out_stride = [size_padded, 1]
out = torch.empty_strided(out_size, out_stride, dtype=out_dtype, device=mat1.device)
return out
def _meta_grouped_mm_common( def _meta_grouped_mm_common(
@ -7542,6 +7531,15 @@ def _meta_grouped_mm_common(
mat_a_is_2d = mat_a.dim() == 2 mat_a_is_2d = mat_a.dim() == 2
mat_b_is_2d = mat_b.dim() == 2 mat_b_is_2d = mat_b.dim() == 2
torch._check(
mat_a.shape[-1] % 16 == 0,
lambda: f"Expected mat_a.shape[-1] to be divisible by 16, but got mat_a.shape[-1]={mat_a.shape[1]}",
)
torch._check(
mat_b.shape[-2] % 16 == 0 and mat_b.shape[-1] % 16 == 0,
lambda: f"Expected mat_b.shape[-2] and mat_b.shape[-1] to be both divisble by 16 but got {mat_b.shape[-2]} and {mat_b.shape[-1]}", # noqa: B950
)
if scaled: if scaled:
def is_row_major(mat): def is_row_major(mat):
@ -7563,7 +7561,7 @@ def _meta_grouped_mm_common(
def check_valid_strides(mat_name, mat): def check_valid_strides(mat_name, mat):
end_dim = mat.dim() - 1 end_dim = mat.dim() - 1
alignment = 16 // mat.element_size() alignment = 16 / mat.element_size()
mat_stride = mat.stride() mat_stride = mat.stride()
if mat_stride[end_dim - 1] == 1 and mat_stride[end_dim] >= max( if mat_stride[end_dim - 1] == 1 and mat_stride[end_dim] >= max(
1, mat.shape[end_dim - 1] 1, mat.shape[end_dim - 1]
@ -7582,7 +7580,7 @@ def _meta_grouped_mm_common(
else: else:
torch._check( torch._check(
False, False,
lambda: f"Invalid strides/sizes, got {mat_stride} for strides and {mat.shape} for sizes.", # noqa: B950 lambda: f"Expected {mat_name} to have a contiguous dimension and not be mat_a-overlapping, got {mat_stride} for strides and {mat.shape} for sizes.", # noqa: B950
) )
check_valid_strides("mat_a", mat_a) check_valid_strides("mat_a", mat_a)
@ -7667,7 +7665,9 @@ def _meta_grouped_mm_common(
lambda: "If output dtype provided, it must be torch.bfloat16.", lambda: "If output dtype provided, it must be torch.bfloat16.",
) )
return _create_grouped_mm_output_tensor(mat_a, mat_b, offs, out_dtype) out_size = _compute_grouped_mm_output_size(mat_a, mat_b, offs)
out_dtype = out_dtype or mat_a.dtype
return torch.empty(out_size, dtype=out_dtype, device=mat_a.device)
@register_meta(aten._grouped_mm) @register_meta(aten._grouped_mm)