error out on negative offs or on K=0 in group gemm (#153226)

Error out if K=0 in one of the grouped gemms, to avoid the hangs reported in #152668.
Also adds a meta function for `_scaled_grouped_mm` (TODO: do the same for `_grouped_mm`, unless it's done already).
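
For illustration only (not part of this PR), a minimal sketch of the case being guarded against, assuming a 2d x 2d grouped GEMM where `offs` partitions the shared K dimension; the shapes, layout, and group sizes below are made-up assumptions:
```
# Hypothetical repro sketch, not from this PR: shapes, layout and alignment
# choices are assumptions. The repeated value in `offs` makes group 1 have K=0;
# previously a zero-size K group could hang (#152668), with this change the
# CUDA kernel asserts instead.
import torch

M, N, K_total = 64, 64, 256
a = torch.randn(M, K_total, device="cuda", dtype=torch.bfloat16)
b = torch.randn(N, K_total, device="cuda", dtype=torch.bfloat16)
offs = torch.tensor([128, 128, 256], device="cuda", dtype=torch.int32)  # K groups: 128, 0, 128
out = torch._grouped_mm(a, b.transpose(-2, -1), offs=offs, out_dtype=torch.bfloat16)
```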

One weird thing I'm seeing: when running all grouped_gemm tests, I'm erroring out with
```
  File "/data/users/ngimel/pytorch/torch/_inductor/graph.py", line 1246, in call_function
    out = lowerings[target](*args, **kwargs)  # type: ignore[index]
  File "/data/users/ngimel/pytorch/torch/_inductor/lowering.py", line 445, in wrapped
    out = decomp_fn(*args, **kwargs)
  File "/data/users/ngimel/pytorch/torch/_inductor/kernel/mm_scaled_grouped.py", line 444, in tuned_scaled_grouped_mm
    if is_nonzero and can_use_triton_kernel(mat_a, mat_b, offs, bias):
  File "/data/users/ngimel/pytorch/torch/_inductor/kernel/mm_scaled_grouped.py", line 375, in can_use_triton_kernel
    offs is not None
  File "/home/ngimel/.conda/envs/pytorch_monarch/lib/python3.10/site-packages/sympy/core/relational.py", line 516, in __bool__
    raise TypeError("cannot determine truth value of Relational")
torch._inductor.exc.InductorError: LoweringException: TypeError: cannot determine truth value of Relational
```
which is weird: there's no relational that sympy has to evaluate in `offs is not None`, and when running this test separately (`test_scaled_grouped_gemm_2d_3d_fast_accum_True_strided_False_use_torch_compile_True_cuda`) it passes. I suspect some autotuning cache has to be reset between runs, but I don't know what to look for.
Edit: that error is "fixed" by setting `dynamic=False`; now, with the correct meta function, something is still wrong with dynamic shapes.
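
For reference, the workaround amounts to compiling the op with dynamic shapes disabled, as the updated test below does:
```
# Sketch of the workaround used in the updated test: compile with dynamic=False.
import torch

f = torch.compile(torch._scaled_grouped_mm, dynamic=False)
```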

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153226
Approved by: https://github.com/kwen2501
Natalia Gimelshein, 2025-05-10 01:13:18 +00:00 (committed by PyTorch MergeBot)
parent 639793c17e
commit 9c99ea2991
3 changed files with 235 additions and 91 deletions


@@ -49,8 +49,8 @@ __global__ void prepare_grouped_gemm_data(
     delta = offs[tid] - start;
     int align = 16 / sizeof(DtypeA);
     CUDA_KERNEL_ASSERT(
-        delta % align == 0 &&
-        "expected dynamic dimension byte size to be multiple of 16 \n");
+        delta >=0 && delta % align == 0 &&
+        "expected dynamic dimension byte size to be non-negative multiple of 16 \n");
   }
   int64_t lda, ldb, ldoutput;
   if (M < 0) {
@@ -81,6 +81,7 @@ __global__ void prepare_grouped_gemm_data(
   } else if (K < 0) {
     // A, B is 2d, output is 3d
     K = delta;
+    CUDA_KERNEL_ASSERT(delta > 0 && "can't handle K=0");
     lda = a_row_major ? tensor_StrideA[0] : tensor_StrideA[1];
     ldb = b_row_major ? tensor_StrideB[0] : tensor_StrideB[1];
     ldoutput = tensor_StrideOutput[1];


@@ -268,8 +268,9 @@ class TestMatmulCuda(TestCase):
             out_ref = torch.mm(a, b.t())
             out_ref.backward(gO)
             self.assertEqual(out, out_ref)
-            self.assertEqual(agrad, a.grad)
-            self.assertEqual(bgrad, b.grad)
+            if agrad is not None:
+                self.assertEqual(agrad, a.grad)
+                self.assertEqual(bgrad, b.grad)

    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
    @xfailIfSM100OrLater
@@ -338,22 +339,28 @@ class TestMatmulCuda(TestCase):
         self.assertTrue(a_contig.is_contiguous() is not strided)
         b_contig = b if b_row_major else b.transpose(-2, -1)
         self.assertTrue(b_contig.is_contiguous() is not strided)
-        offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
-        out = torch._grouped_mm(a, b.transpose(-2, -1), offs=offs,
-                                out_dtype=torch.bfloat16)
-        gO = torch.rand_like(out)
-        out.backward(gO)
-        offs_cpu = offs.cpu()
-        alist, agradlist, gOlist, outlist = [], [], [], []
-        start = 0
-        for i in range(n_groups):
-            alist.append(a[start:offs_cpu[i]])
-            agradlist.append(a.grad[start:offs_cpu[i]])
-            outlist.append(out[start:offs_cpu[i]])
-            gOlist.append(gO[start:offs_cpu[i]])
-            start = offs_cpu[i]
-        self.grouped_mm_helper(alist, b, gOlist, agradlist, b.grad, outlist)
+        for check_zero_size in (False, True):
+            a.grad = None
+            b.grad = None
+            offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
+            if check_zero_size:
+                offs[0] = offs[1]
+            out = torch._grouped_mm(a, b.transpose(-2, -1), offs=offs,
+                                    out_dtype=torch.bfloat16)
+            gO = torch.rand_like(out)
+            if not check_zero_size:
+                out.backward(gO)
+            offs_cpu = offs.cpu()
+            alist, agradlist, gOlist, outlist = [], [], [], []
+            bgradlist = [None] * n_groups if check_zero_size else b.grad
+            start = 0
+            for i in range(n_groups):
+                alist.append(a[start:offs_cpu[i]])
+                agradlist.append(None if check_zero_size else a.grad[start:offs_cpu[i]])
+                outlist.append(out[start:offs_cpu[i]])
+                gOlist.append(gO[start:offs_cpu[i]])
+                start = offs_cpu[i]
+            self.grouped_mm_helper(alist, b, gOlist, agradlist, bgradlist, outlist)

    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@@ -418,21 +425,26 @@ class TestMatmulCuda(TestCase):
         self.assertTrue(a_contig.is_contiguous() is not strided)
         b_contig = b if b_row_major else b.transpose(-2, -1)
         self.assertTrue(b_contig.is_contiguous() is not strided)
-        offs = torch.arange(n, n_groups * n + 1, n, device="cuda", dtype=torch.int32)
-        out = torch._grouped_mm(a, b.transpose(-2, -1), offs=offs,
-                                out_dtype=torch.bfloat16)
-        gO = torch.rand_like(out)
-        out.backward(gO)
-        offs_cpu = offs.cpu()
-        blist, outlist, bgradlist, gOlist = [], [], [], []
-        start = 0
-        for i in range(n_groups):
-            blist.append(b[start:offs_cpu[i]])
-            bgradlist.append(b.grad[start:offs_cpu[i]])
-            outlist.append(out[:, start:offs_cpu[i]])
-            gOlist.append(gO[:, start:offs_cpu[i]])
-            start = offs_cpu[i]
-        self.grouped_mm_helper(a, blist, gOlist, a.grad, bgradlist, outlist)
+        for check_zero_size in (False, True):
+            offs = torch.arange(n, n_groups * n + 1, n, device="cuda", dtype=torch.int32)
+            if check_zero_size:
+                offs[0] = offs[1]
+            out = torch._grouped_mm(a, b.transpose(-2, -1), offs=offs,
+                                    out_dtype=torch.bfloat16)
+            gO = torch.rand_like(out)
+            if not check_zero_size:
+                out.backward(gO)
+            offs_cpu = offs.cpu()
+            blist, outlist, bgradlist, gOlist = [], [], [], []
+            agradlist = [None] * n_groups if check_zero_size else a.grad
+            start = 0
+            for i in range(n_groups):
+                blist.append(b[start:offs_cpu[i]])
+                bgradlist.append(b.grad[start:offs_cpu[i]])
+                outlist.append(out[:, start:offs_cpu[i]])
+                gOlist.append(gO[:, start:offs_cpu[i]])
+                start = offs_cpu[i]
+            self.grouped_mm_helper(a, blist, gOlist, agradlist, bgradlist, outlist)

    @onlyCUDA
@@ -1611,24 +1623,27 @@ class TestFP8Matmul(TestCase):
         b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
         self.assertTrue(a.is_contiguous() is not strided)
         self.assertTrue(b.is_contiguous() is not strided)
-        offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
-        scale_a = torch.arange(n_groups * m, device="cuda", dtype=torch.float32)
-        scale_b = torch.ones(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)
+        for check_zero_size in (True, False):
+            offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
+            if check_zero_size:
+                offs[0] = offs[1]
+            scale_a = torch.arange(n_groups * m, device="cuda", dtype=torch.float32)
+            scale_b = torch.ones(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)

-        f = torch._scaled_grouped_mm
-        f = torch.compile(f) if use_torch_compile else f
-        out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
-                out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
+            f = torch._scaled_grouped_mm
+            f = torch.compile(f, dynamic=False) if use_torch_compile else f
+            out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
+                    out_dtype=torch.bfloat16, use_fast_accum=fast_accum)

-        offs_cpu = offs.cpu()
-        alist, ascalelist, outlist = [], [], []
-        start = 0
-        for i in range(n_groups):
-            alist.append(a[start:offs_cpu[i]])
-            ascalelist.append(scale_a[start:offs_cpu[i]])
-            outlist.append(out[start:offs_cpu[i]])
-            start = offs_cpu[i]
-        self.scaled_grouped_mm_helper(alist, b, ascalelist, scale_b, outlist, fast_accum)
+            offs_cpu = offs.cpu()
+            alist, ascalelist, outlist = [], [], []
+            start = 0
+            for i in range(n_groups):
+                alist.append(a[start:offs_cpu[i]])
+                ascalelist.append(scale_a[start:offs_cpu[i]])
+                outlist.append(out[start:offs_cpu[i]])
+                start = offs_cpu[i]
+            self.scaled_grouped_mm_helper(alist, b, ascalelist, scale_b, outlist, fast_accum)

    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@@ -1672,21 +1687,24 @@ class TestFP8Matmul(TestCase):
         self.assertTrue(b.is_contiguous() is not strided)
         scale_a = torch.arange(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m)
         scale_b = torch.arange(n_groups * n, device="cuda", dtype=torch.float32)
-        offs = torch.arange(n, n_groups * n + 1, n, device="cuda", dtype=torch.int32)
+        for check_zero_size in (True, False):
+            offs = torch.arange(n, n_groups * n + 1, n, device="cuda", dtype=torch.int32)
+            if check_zero_size:
+                offs[0] = offs[1]

-        f = torch._scaled_grouped_mm
-        f = torch.compile(f) if use_torch_compile else f
-        out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
-                out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
-        offs_cpu = offs.cpu()
-        blist, bscalelist, outlist = [], [], []
-        start = 0
-        for i in range(n_groups):
-            blist.append(b[start:offs_cpu[i]])
-            bscalelist.append(scale_b[start:offs_cpu[i]])
-            outlist.append(out[:, start:offs_cpu[i]])
-            start = offs_cpu[i]
-        self.scaled_grouped_mm_helper(a, blist, scale_a, bscalelist, outlist, fast_accum)
+            f = torch._scaled_grouped_mm
+            f = torch.compile(f) if use_torch_compile else f
+            out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
+                    out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
+            offs_cpu = offs.cpu()
+            blist, bscalelist, outlist = [], [], []
+            start = 0
+            for i in range(n_groups):
+                blist.append(b[start:offs_cpu[i]])
+                bscalelist.append(scale_b[start:offs_cpu[i]])
+                outlist.append(out[:, start:offs_cpu[i]])
+                start = offs_cpu[i]
+            self.scaled_grouped_mm_helper(a, blist, scale_a, bscalelist, outlist, fast_accum)

    @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg)
@@ -1869,7 +1887,7 @@ class TestMixedDtypesLinearCuda(TestCase):
 instantiate_device_type_tests(TestMatmulCuda, globals(), except_for="cpu")
 instantiate_device_type_tests(TestMixedDtypesLinearCuda, globals(), except_for="cpu")
-instantiate_device_type_tests(TestFP8Matmul, globals())
+instantiate_device_type_tests(TestFP8Matmul, globals(), except_for="cpu")

 if __name__ == '__main__':
     TestCase._default_dtype_check_enabled = True


@@ -7272,6 +7272,30 @@ def sigmoid(self: Tensor) -> Tensor:
     return torch.empty_like(self, dtype=result_dtype)


+def _compute_grouped_gemm_output_size(mat1, mat2, offs):
+    mat1_is_2d = mat1.dim() == 2
+    mat2_is_2d = mat2.dim() == 2
+
+    if mat1_is_2d:
+        if mat2_is_2d:
+            return offs.size(0), mat1.size(0), mat2.size(1)
+        else:
+            torch._check(
+                offs.size(0) == mat2.size(0), "matrix batch sizes have to match"
+            )
+            return mat1.size(0), mat2.size(-1)
+    else:
+        if mat2_is_2d:
+            torch._check(
+                offs.size(0) == mat1.size(0), "matrix batch sizes have to match"
+            )
+            return mat1.size(1), mat2.size(1)
+        else:
+            # regular bmm
+            torch._check(mat1.size(0) == mat2.size(0), "batched dimension has to match")
+            return mat1.size(0), mat1.size(1), mat2.size(-1)
+
+
 @register_meta(aten._grouped_mm)
 @out_wrapper()
 def grouped_mm(
@@ -7294,37 +7318,138 @@ def grouped_mm(
     out_dtype = out_dtype or mat1.dtype
     torch._check(bias is None, lambda: "bias not supported yet")

-    def _compute_grouped_gemm_output_size(mat1, mat2, offs):
-        mat1_is_2d = mat1.dim() == 2
-        mat2_is_2d = mat2.dim() == 2
-        if mat1_is_2d:
-            if mat2_is_2d:
-                return offs.size(0), mat1.size(0), mat2.size(1)
-            else:
-                torch._check(
-                    offs.size(0) == mat2.size(0), "matrix batch sizes have to match"
-                )
-                return mat1.size(0), mat2.size(-1)
-        else:
-            if mat2_is_2d:
-                torch._check(
-                    offs.size(0) == mat1.size(0), "matrix batch sizes have to match"
-                )
-                return mat1.size(1), mat2.size(1)
-            else:
-                # regular bmm
-                torch._check(
-                    mat1.size(0) == mat2.size(0), "batched dimension has to match"
-                )
-                return mat1.size(0), mat1.size(1), mat2.size(-1)
-
     out_size = _compute_grouped_gemm_output_size(mat1, mat2, offs)
     out = mat1.new_empty(out_size, dtype=out_dtype)
     return out


+@register_meta([aten._scaled_grouped_mm.default])
+def meta_scaled_grouped_mm(
+    mat_a: torch.Tensor,
+    mat_b: torch.Tensor,
+    scale_a: torch.Tensor,
+    scale_b: torch.Tensor,
+    offs: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+    scale_result: Optional[torch.Tensor] = None,
+    out_dtype: Optional[torch.dtype] = None,
+    use_fast_accum: bool = False,
+):
+    # Check dimensions
+    torch._check(
+        mat_a.dim() == 2 or mat_a.dim() == 3, lambda: "mat_a has to be 2 or 3d"
+    )
+    torch._check(
+        mat_b.dim() == 2 or mat_b.dim() == 3, lambda: "mat_b has to be 2 or 3d"
+    )
+    a_is_2d = mat_a.dim() == 2
+    b_is_2d = mat_b.dim() == 2
+
+    # Check offsets
+    torch._check(
+        (offs is not None) == (a_is_2d or b_is_2d),
+        lambda: "Have to provide offsets if there is a 2d matrix",
+    )
+    if offs is not None:
+        torch._check(offs.dim() == 1, lambda: "offs has to be 1D")
+        torch._check(offs.dtype == torch.int, lambda: "Offsets have to be int32")
+
+    # Check matrix sizes
+    torch._check(
+        mat_a.size(-1) % 16 == 0,
+        lambda: f"Expected trailing dimension of mat_a to be divisible by 16 but got mat1 shape: {mat_a.size()}",
+    )
+    torch._check(
+        mat_b.size(-2) % 16 == 0 and mat_b.size(-1) % 16 == 0,
+        lambda: f"Expected mat_b shape to be divisible by 16 but got mat_b shape: {mat_b.size()}",
+    )
+
+    # Check scales
+    torch._check(
+        scale_a.dtype == torch.float and scale_b.dtype == torch.float,
+        lambda: "Both scale_a and scale_b must be float (fp32) tensors.",
+    )
+
+    # Check scale dimensions
+    scale_multiplier = offs.size(0) if (a_is_2d and b_is_2d) else 1  # type: ignore[union-attr]
+    if a_is_2d:
+        torch._check(
+            scale_a.dim() == 1,
+            lambda: f"scale must be a 1D tensor for 2D mat_a, but got {scale_a.dim()}D",
+        )
+        torch._check(scale_a.is_contiguous(), lambda: "scale_a must be contiguous")
+        torch._check(
+            scale_a.size(0) == mat_a.size(0) * scale_multiplier,
+            lambda: "scale must have the same length as mat_a",
+        )
+    else:
+        torch._check(
+            scale_a.dim() == 2,
+            lambda: f"scale must be a 2D tensor for 3D mat_a, but got {scale_a.dim()}D",
+        )
+        torch._check(
+            scale_a.stride(1) == 1,
+            lambda: "scale_a must be contiguous in the last dimension",
+        )
+        torch._check(
+            scale_a.size(0) == mat_a.size(0),
+            lambda: "scale must have the same batch dimension as mat_a",
+        )
+        torch._check(
+            scale_a.size(1) == mat_a.size(1),
+            lambda: "scale must have the same first dimension as mat_a",
+        )
+
+    # Similar checks for scale_b
+    if b_is_2d:
+        torch._check(
+            scale_b.dim() == 1,
+            lambda: f"scale must be a 1D tensor for 2D mat_b, but got {scale_b.dim()}D",
+        )
+        torch._check(scale_b.is_contiguous(), lambda: "scale_b must be contiguous")
+        torch._check(
+            scale_b.size(0) == mat_b.size(1) * scale_multiplier,
+            lambda: "scale must have the same length as mat_b",
+        )
+    else:
+        torch._check(
+            scale_b.dim() == 2,
+            lambda: f"scale must be a 2D tensor for 3D mat_b, but got {scale_b.dim()}D",
+        )
+        torch._check(
+            scale_b.stride(1) == 1,
+            lambda: "scale_b must be contiguous in the last dimension",
+        )
+        torch._check(
+            scale_b.size(0) == mat_b.size(0),
+            lambda: "scale must have the same batch dimension as mat_b",
+        )
+        torch._check(
+            scale_b.size(1) == mat_b.size(2),
+            lambda: "scale must have the same last dimension as mat_b",
+        )
+
+    # Check bias
+    torch._check(bias is None, lambda: "Bias not supported yet")
+
+    # Check output dtype
+    out_dtype_ = out_dtype if out_dtype is not None else mat_a.dtype
+    torch._check(
+        out_dtype_ == torch.bfloat16,
+        lambda: "Only bf16 high precision output types are supported for grouped gemm",
+    )
+
+    # Compute output size
+    out_size = _compute_grouped_gemm_output_size(mat_a, mat_b, offs)
+    out = mat_a.new_empty(out_size, dtype=out_dtype)
+    return out
+
+
 @register_meta(aten._softmax)
 @out_wrapper()
 def softmax(x: Tensor, dim: int, half_to_float: bool) -> Tensor: