error out on negative offs or on K=0 in group gemm (#153226)
Error out if K=0 in one of the grouped gemms to avoid the hangs reported in #152668. Also adds a meta function for `_scaled_grouped_mm` (TODO: do the same for `_grouped_mm`, unless that's already done).

One weird thing I'm seeing: when running all grouped_gemm tests, I'm erroring out with

```
File "/data/users/ngimel/pytorch/torch/_inductor/graph.py", line 1246, in call_function
    out = lowerings[target](*args, **kwargs)  # type: ignore[index]
File "/data/users/ngimel/pytorch/torch/_inductor/lowering.py", line 445, in wrapped
    out = decomp_fn(*args, **kwargs)
File "/data/users/ngimel/pytorch/torch/_inductor/kernel/mm_scaled_grouped.py", line 444, in tuned_scaled_grouped_mm
    if is_nonzero and can_use_triton_kernel(mat_a, mat_b, offs, bias):
File "/data/users/ngimel/pytorch/torch/_inductor/kernel/mm_scaled_grouped.py", line 375, in can_use_triton_kernel
    offs is not None
File "/home/ngimel/.conda/envs/pytorch_monarch/lib/python3.10/site-packages/sympy/core/relational.py", line 516, in __bool__
    raise TypeError("cannot determine truth value of Relational")
torch._inductor.exc.InductorError: LoweringException: TypeError: cannot determine truth value of Relational
```

which is weird: there is no relational that sympy has to evaluate in `offs is not None`, and when this test is run separately (`test_scaled_grouped_gemm_2d_3d_fast_accum_True_strided_False_use_torch_compile_True_cuda`) it passes. I suspect some autotuning cache has to be reset between runs, but I don't know what to look for.

Edit: that error is "fixed" by setting `dynamic=False`; now, with the correct meta function, something is wrong with dynamic shapes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153226
Approved by: https://github.com/kwen2501
parent 639793c17e, commit 9c99ea2991
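For context on the `offs` convention exercised in the diffs below: `offs` holds cumulative group boundaries along the dynamic dimension (the rows of a 2d operand, or K when both operands are 2d), so group `i` covers `offs[i-1]:offs[i]`, with an implicit leading 0. Two equal consecutive entries therefore give an empty group, and a decreasing entry gives a negative extent; the kernel change below rejects the negative case outright and errors out when an empty group would mean K=0. A small illustrative sketch (sizes are made up, not taken from the tests):

```python
import torch

m, n_groups = 16, 4  # made-up sizes

# Cumulative boundaries: group i covers offs[i-1]:offs[i] (with an implicit leading 0).
offs = torch.arange(m, n_groups * m + 1, m, dtype=torch.int32)  # tensor([16, 32, 48, 64])

# Collapsing the first boundary onto the second, as the tests' check_zero_size path does,
# produces one empty group.
offs[0] = offs[1]
group_sizes = torch.diff(offs, prepend=torch.zeros(1, dtype=torch.int32))
print(group_sizes)  # tensor([32,  0, 16, 16], dtype=torch.int32)
```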
```diff
@@ -49,8 +49,8 @@ __global__ void prepare_grouped_gemm_data(
       delta = offs[tid] - start;
       int align = 16 / sizeof(DtypeA);
       CUDA_KERNEL_ASSERT(
-          delta % align == 0 &&
-          "expected dynamic dimension byte size to be multiple of 16 \n");
+          delta >=0 && delta % align == 0 &&
+          "expected dynamic dimension byte size to be non-negative multiple of 16 \n");
     }
     int64_t lda, ldb, ldoutput;
     if (M < 0) {
@@ -81,6 +81,7 @@ __global__ void prepare_grouped_gemm_data(
     } else if (K < 0) {
       // A, B is 2d, output is 3d
       K = delta;
+      CUDA_KERNEL_ASSERT(delta > 0 && "can't handle K=0");
       lda = a_row_major ? tensor_StrideA[0] : tensor_StrideA[1];
       ldb = b_row_major ? tensor_StrideB[0] : tensor_StrideB[1];
       ldoutput = tensor_StrideOutput[1];
```
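The assertions above run on the device. As a rough host-side illustration of the same per-group checks, here is a sketch with a hypothetical helper (not part of the kernel or of this PR):

```python
import torch

def validate_offs(offs: torch.Tensor, elem_size: int, k_is_dynamic: bool) -> None:
    """Hypothetical host-side mirror of the kernel's per-group asserts.

    elem_size plays the role of sizeof(DtypeA); k_is_dynamic corresponds to the
    branch where A and B are 2d and each group's delta becomes that group's K.
    """
    align = 16 // elem_size  # kernel: int align = 16 / sizeof(DtypeA);
    start = 0
    for end in offs.tolist():
        delta = end - start
        # kernel: delta >= 0 && delta % align == 0
        assert delta >= 0 and delta % align == 0, \
            "expected dynamic dimension byte size to be non-negative multiple of 16"
        if k_is_dynamic:
            # kernel: CUDA_KERNEL_ASSERT(delta > 0 && "can't handle K=0")
            assert delta > 0, "can't handle K=0"
        start = end

# bf16 elements are 2 bytes, so each group extent must be a non-negative multiple of 8.
validate_offs(torch.tensor([16, 32, 32, 48], dtype=torch.int32), elem_size=2, k_is_dynamic=False)
```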
```diff
@@ -268,8 +268,9 @@ class TestMatmulCuda(TestCase):
             out_ref = torch.mm(a, b.t())
             out_ref.backward(gO)
             self.assertEqual(out, out_ref)
-            self.assertEqual(agrad, a.grad)
-            self.assertEqual(bgrad, b.grad)
+            if agrad is not None:
+                self.assertEqual(agrad, a.grad)
+                self.assertEqual(bgrad, b.grad)
 
     @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
     @xfailIfSM100OrLater
```
```diff
@@ -338,22 +339,28 @@ class TestMatmulCuda(TestCase):
         self.assertTrue(a_contig.is_contiguous() is not strided)
         b_contig = b if b_row_major else b.transpose(-2, -1)
         self.assertTrue(b_contig.is_contiguous() is not strided)
-        offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
-
-        out = torch._grouped_mm(a, b.transpose(-2, -1), offs=offs,
-                                out_dtype=torch.bfloat16)
-        gO = torch.rand_like(out)
-        out.backward(gO)
-        offs_cpu = offs.cpu()
-        alist, agradlist, gOlist, outlist = [], [], [], []
-        start = 0
-        for i in range(n_groups):
-            alist.append(a[start:offs_cpu[i]])
-            agradlist.append(a.grad[start:offs_cpu[i]])
-            outlist.append(out[start:offs_cpu[i]])
-            gOlist.append(gO[start:offs_cpu[i]])
-            start = offs_cpu[i]
-        self.grouped_mm_helper(alist, b, gOlist, agradlist, b.grad, outlist)
+        for check_zero_size in (False, True):
+            a.grad = None
+            b.grad = None
+            offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
+            if check_zero_size:
+                offs[0] = offs[1]
+            out = torch._grouped_mm(a, b.transpose(-2, -1), offs=offs,
+                                    out_dtype=torch.bfloat16)
+            gO = torch.rand_like(out)
+            if not check_zero_size:
+                out.backward(gO)
+            offs_cpu = offs.cpu()
+            alist, agradlist, gOlist, outlist = [], [], [], []
+            bgradlist = [None] * n_groups if check_zero_size else b.grad
+            start = 0
+            for i in range(n_groups):
+                alist.append(a[start:offs_cpu[i]])
+                agradlist.append(None if check_zero_size else a.grad[start:offs_cpu[i]])
+                outlist.append(out[start:offs_cpu[i]])
+                gOlist.append(gO[start:offs_cpu[i]])
+                start = offs_cpu[i]
+            self.grouped_mm_helper(alist, b, gOlist, agradlist, bgradlist, outlist)
 
 
     @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
```
```diff
@@ -418,21 +425,26 @@ class TestMatmulCuda(TestCase):
         self.assertTrue(a_contig.is_contiguous() is not strided)
         b_contig = b if b_row_major else b.transpose(-2, -1)
         self.assertTrue(b_contig.is_contiguous() is not strided)
-        offs = torch.arange(n, n_groups * n + 1, n, device="cuda", dtype=torch.int32)
-        out = torch._grouped_mm(a, b.transpose(-2, -1), offs=offs,
-                                out_dtype=torch.bfloat16)
-        gO = torch.rand_like(out)
-        out.backward(gO)
-        offs_cpu = offs.cpu()
-        blist, outlist, bgradlist, gOlist = [], [], [], []
-        start = 0
-        for i in range(n_groups):
-            blist.append(b[start:offs_cpu[i]])
-            bgradlist.append(b.grad[start:offs_cpu[i]])
-            outlist.append(out[:, start:offs_cpu[i]])
-            gOlist.append(gO[:, start:offs_cpu[i]])
-            start = offs_cpu[i]
-        self.grouped_mm_helper(a, blist, gOlist, a.grad, bgradlist, outlist)
+        for check_zero_size in (False, True):
+            offs = torch.arange(n, n_groups * n + 1, n, device="cuda", dtype=torch.int32)
+            if check_zero_size:
+                offs[0] = offs[1]
+            out = torch._grouped_mm(a, b.transpose(-2, -1), offs=offs,
+                                    out_dtype=torch.bfloat16)
+            gO = torch.rand_like(out)
+            if not check_zero_size:
+                out.backward(gO)
+            offs_cpu = offs.cpu()
+            blist, outlist, bgradlist, gOlist = [], [], [], []
+            agradlist = [None] * n_groups if check_zero_size else a.grad
+            start = 0
+            for i in range(n_groups):
+                blist.append(b[start:offs_cpu[i]])
+                bgradlist.append(b.grad[start:offs_cpu[i]])
+                outlist.append(out[:, start:offs_cpu[i]])
+                gOlist.append(gO[:, start:offs_cpu[i]])
+                start = offs_cpu[i]
+            self.grouped_mm_helper(a, blist, gOlist, agradlist, bgradlist, outlist)
 
 
     @onlyCUDA
```
```diff
@@ -1611,24 +1623,27 @@ class TestFP8Matmul(TestCase):
         b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
         self.assertTrue(a.is_contiguous() is not strided)
         self.assertTrue(b.is_contiguous() is not strided)
-        offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
-        scale_a = torch.arange(n_groups * m, device="cuda", dtype=torch.float32)
-        scale_b = torch.ones(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)
-
-        f = torch._scaled_grouped_mm
-        f = torch.compile(f) if use_torch_compile else f
-        out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
-                out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
-
-        offs_cpu = offs.cpu()
-        alist, ascalelist, outlist = [], [], []
-        start = 0
-        for i in range(n_groups):
-            alist.append(a[start:offs_cpu[i]])
-            ascalelist.append(scale_a[start:offs_cpu[i]])
-            outlist.append(out[start:offs_cpu[i]])
-            start = offs_cpu[i]
-        self.scaled_grouped_mm_helper(alist, b, ascalelist, scale_b, outlist, fast_accum)
+        for check_zero_size in (True, False):
+            offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
+            if check_zero_size:
+                offs[0] = offs[1]
+            scale_a = torch.arange(n_groups * m, device="cuda", dtype=torch.float32)
+            scale_b = torch.ones(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)
+
+            f = torch._scaled_grouped_mm
+            f = torch.compile(f, dynamic=False) if use_torch_compile else f
+            out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
+                    out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
+
+            offs_cpu = offs.cpu()
+            alist, ascalelist, outlist = [], [], []
+            start = 0
+            for i in range(n_groups):
+                alist.append(a[start:offs_cpu[i]])
+                ascalelist.append(scale_a[start:offs_cpu[i]])
+                outlist.append(out[start:offs_cpu[i]])
+                start = offs_cpu[i]
+            self.scaled_grouped_mm_helper(alist, b, ascalelist, scale_b, outlist, fast_accum)
 
 
     @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
```
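Note the `dynamic=False` in the compiled variant above; it is the workaround mentioned in the description for the "cannot determine truth value of Relational" lowering error. A hedged sketch of how the compiled op ends up being invoked (sizes are illustrative, not the test's parametrization, and fp8-capable hardware such as H100 is needed to actually run it):

```python
import torch

# Illustrative sizes; the real test parametrizes m, n, k, n_groups, fast_accum, etc.
m, n, k, n_groups = 16, 32, 64, 4
a = torch.randn(n_groups * m, k, device="cuda").to(torch.float8_e4m3fn)   # 2d operand
b = torch.randn(n_groups, n, k, device="cuda").to(torch.float8_e4m3fn)    # 3d operand
scale_a = torch.ones(n_groups * m, device="cuda", dtype=torch.float32)    # one scale per row of a
scale_b = torch.ones(n_groups, n, device="cuda", dtype=torch.float32)     # per-group, per-column scales
offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)

# dynamic=False sidesteps the Relational lowering error described in the commit message;
# with dynamic shapes the compiled path still misbehaves at the time of this change.
f = torch.compile(torch._scaled_grouped_mm, dynamic=False)
out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
        out_dtype=torch.bfloat16, use_fast_accum=True)
print(out.shape)  # torch.Size([64, 32]), i.e. (n_groups * m, n)
```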
```diff
@@ -1672,21 +1687,24 @@ class TestFP8Matmul(TestCase):
         self.assertTrue(b.is_contiguous() is not strided)
         scale_a = torch.arange(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m)
         scale_b = torch.arange(n_groups * n, device="cuda", dtype=torch.float32)
-        offs = torch.arange(n, n_groups * n + 1, n, device="cuda", dtype=torch.int32)
-
-        f = torch._scaled_grouped_mm
-        f = torch.compile(f) if use_torch_compile else f
-        out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
-                out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
-        offs_cpu = offs.cpu()
-        blist, bscalelist, outlist = [], [], []
-        start = 0
-        for i in range(n_groups):
-            blist.append(b[start:offs_cpu[i]])
-            bscalelist.append(scale_b[start:offs_cpu[i]])
-            outlist.append(out[:, start:offs_cpu[i]])
-            start = offs_cpu[i]
-        self.scaled_grouped_mm_helper(a, blist, scale_a, bscalelist, outlist, fast_accum)
+        for check_zero_size in (True, False):
+            offs = torch.arange(n, n_groups * n + 1, n, device="cuda", dtype=torch.int32)
+            if check_zero_size:
+                offs[0] = offs[1]
+
+            f = torch._scaled_grouped_mm
+            f = torch.compile(f) if use_torch_compile else f
+            out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
+                    out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
+            offs_cpu = offs.cpu()
+            blist, bscalelist, outlist = [], [], []
+            start = 0
+            for i in range(n_groups):
+                blist.append(b[start:offs_cpu[i]])
+                bscalelist.append(scale_b[start:offs_cpu[i]])
+                outlist.append(out[:, start:offs_cpu[i]])
+                start = offs_cpu[i]
+            self.scaled_grouped_mm_helper(a, blist, scale_a, bscalelist, outlist, fast_accum)
 
 
     @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg)
```
```diff
@@ -1869,7 +1887,7 @@ class TestMixedDtypesLinearCuda(TestCase):
 
 instantiate_device_type_tests(TestMatmulCuda, globals(), except_for="cpu")
 instantiate_device_type_tests(TestMixedDtypesLinearCuda, globals(), except_for="cpu")
-instantiate_device_type_tests(TestFP8Matmul, globals())
+instantiate_device_type_tests(TestFP8Matmul, globals(), except_for="cpu")
 
 if __name__ == '__main__':
     TestCase._default_dtype_check_enabled = True
```
```diff
@@ -7272,6 +7272,30 @@ def sigmoid(self: Tensor) -> Tensor:
     return torch.empty_like(self, dtype=result_dtype)
 
 
+def _compute_grouped_gemm_output_size(mat1, mat2, offs):
+    mat1_is_2d = mat1.dim() == 2
+    mat2_is_2d = mat2.dim() == 2
+
+    if mat1_is_2d:
+        if mat2_is_2d:
+            return offs.size(0), mat1.size(0), mat2.size(1)
+        else:
+            torch._check(
+                offs.size(0) == mat2.size(0), "matrix batch sizes have to match"
+            )
+            return mat1.size(0), mat2.size(-1)
+    else:
+        if mat2_is_2d:
+            torch._check(
+                offs.size(0) == mat1.size(0), "matrix batch sizes have to match"
+            )
+            return mat1.size(1), mat2.size(1)
+        else:
+            # regular bmm
+            torch._check(mat1.size(0) == mat2.size(0), "batched dimension has to match")
+            return mat1.size(0), mat1.size(1), mat2.size(-1)
+
+
 @register_meta(aten._grouped_mm)
 @out_wrapper()
 def grouped_mm(
```
```diff
@@ -7294,37 +7318,138 @@ def grouped_mm(
     out_dtype = out_dtype or mat1.dtype
     torch._check(bias is None, lambda: "bias not supported yet")
 
-    def _compute_grouped_gemm_output_size(mat1, mat2, offs):
-        mat1_is_2d = mat1.dim() == 2
-        mat2_is_2d = mat2.dim() == 2
-
-        if mat1_is_2d:
-            if mat2_is_2d:
-                return offs.size(0), mat1.size(0), mat2.size(1)
-            else:
-                torch._check(
-                    offs.size(0) == mat2.size(0), "matrix batch sizes have to match"
-                )
-                return mat1.size(0), mat2.size(-1)
-        else:
-            if mat2_is_2d:
-                torch._check(
-                    offs.size(0) == mat1.size(0), "matrix batch sizes have to match"
-                )
-                return mat1.size(1), mat2.size(1)
-            else:
-                # regular bmm
-                torch._check(
-                    mat1.size(0) == mat2.size(0), "batched dimension has to match"
-                )
-                return mat1.size(0), mat1.size(1), mat2.size(-1)
-
     out_size = _compute_grouped_gemm_output_size(mat1, mat2, offs)
     out = mat1.new_empty(out_size, dtype=out_dtype)
 
     return out
 
 
+@register_meta([aten._scaled_grouped_mm.default])
+def meta_scaled_grouped_mm(
+    mat_a: torch.Tensor,
+    mat_b: torch.Tensor,
+    scale_a: torch.Tensor,
+    scale_b: torch.Tensor,
+    offs: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+    scale_result: Optional[torch.Tensor] = None,
+    out_dtype: Optional[torch.dtype] = None,
+    use_fast_accum: bool = False,
+):
+    # Check dimensions
+    torch._check(
+        mat_a.dim() == 2 or mat_a.dim() == 3, lambda: "mat_a has to be 2 or 3d"
+    )
+    torch._check(
+        mat_b.dim() == 2 or mat_b.dim() == 3, lambda: "mat_b has to be 2 or 3d"
+    )
+
+    a_is_2d = mat_a.dim() == 2
+    b_is_2d = mat_b.dim() == 2
+
+    # Check offsets
+    torch._check(
+        (offs is not None) == (a_is_2d or b_is_2d),
+        lambda: "Have to provide offsets if there is a 2d matrix",
+    )
+
+    if offs is not None:
+        torch._check(offs.dim() == 1, lambda: "offs has to be 1D")
+        torch._check(offs.dtype == torch.int, lambda: "Offsets have to be int32")
+
+    # Check matrix sizes
+    torch._check(
+        mat_a.size(-1) % 16 == 0,
+        lambda: f"Expected trailing dimension of mat_a to be divisible by 16 but got mat1 shape: {mat_a.size()}",
+    )
+    torch._check(
+        mat_b.size(-2) % 16 == 0 and mat_b.size(-1) % 16 == 0,
+        lambda: f"Expected mat_b shape to be divisible by 16 but got mat_b shape: {mat_b.size()}",
+    )
+
+    # Check scales
+    torch._check(
+        scale_a.dtype == torch.float and scale_b.dtype == torch.float,
+        lambda: "Both scale_a and scale_b must be float (fp32) tensors.",
+    )
+
+    # Check scale dimensions
+    scale_multiplier = offs.size(0) if (a_is_2d and b_is_2d) else 1  # type: ignore[union-attr]
+
+    if a_is_2d:
+        torch._check(
+            scale_a.dim() == 1,
+            lambda: f"scale must be a 1D tensor for 2D mat_a, but got {scale_a.dim()}D",
+        )
+        torch._check(scale_a.is_contiguous(), lambda: "scale_a must be contiguous")
+        torch._check(
+            scale_a.size(0) == mat_a.size(0) * scale_multiplier,
+            lambda: "scale must have the same length as mat_a",
+        )
+    else:
+        torch._check(
+            scale_a.dim() == 2,
+            lambda: f"scale must be a 2D tensor for 3D mat_a, but got {scale_a.dim()}D",
+        )
+        torch._check(
+            scale_a.stride(1) == 1,
+            lambda: "scale_a must be contiguous in the last dimension",
+        )
+        torch._check(
+            scale_a.size(0) == mat_a.size(0),
+            lambda: "scale must have the same batch dimension as mat_a",
+        )
+        torch._check(
+            scale_a.size(1) == mat_a.size(1),
+            lambda: "scale must have the same first dimension as mat_a",
+        )
+
+    # Similar checks for scale_b
+    if b_is_2d:
+        torch._check(
+            scale_b.dim() == 1,
+            lambda: f"scale must be a 1D tensor for 2D mat_b, but got {scale_b.dim()}D",
+        )
+        torch._check(scale_b.is_contiguous(), lambda: "scale_b must be contiguous")
+        torch._check(
+            scale_b.size(0) == mat_b.size(1) * scale_multiplier,
+            lambda: "scale must have the same length as mat_b",
+        )
+    else:
+        torch._check(
+            scale_b.dim() == 2,
+            lambda: f"scale must be a 2D tensor for 3D mat_b, but got {scale_b.dim()}D",
+        )
+        torch._check(
+            scale_b.stride(1) == 1,
+            lambda: "scale_b must be contiguous in the last dimension",
+        )
+        torch._check(
+            scale_b.size(0) == mat_b.size(0),
+            lambda: "scale must have the same batch dimension as mat_b",
+        )
+        torch._check(
+            scale_b.size(1) == mat_b.size(2),
+            lambda: "scale must have the same last dimension as mat_b",
+        )
+
+    # Check bias
+    torch._check(bias is None, lambda: "Bias not supported yet")
+
+    # Check output dtype
+    out_dtype_ = out_dtype if out_dtype is not None else mat_a.dtype
+    torch._check(
+        out_dtype_ == torch.bfloat16,
+        lambda: "Only bf16 high precision output types are supported for grouped gemm",
+    )
+
+    # Compute output size
+    out_size = _compute_grouped_gemm_output_size(mat_a, mat_b, offs)
+    out = mat_a.new_empty(out_size, dtype=out_dtype)
+
+    return out
+
+
 @register_meta(aten._softmax)
 @out_wrapper()
 def softmax(x: Tensor, dim: int, half_to_float: bool) -> Tensor:
```
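For reference, the output sizes `_compute_grouped_gemm_output_size` produces for the four 2d/3d combinations, checked on the meta device. Sizes are illustrative, the helper is private, and the import path is an assumption based on the diff above:

```python
import torch
# Private helper; module-level after this change (path assumed from the diff above).
from torch._meta_registrations import _compute_grouped_gemm_output_size

G, M, N, K = 4, 32, 64, 128  # illustrative group count and dimensions
offs = torch.empty(G, device="meta", dtype=torch.int32)

a2 = torch.empty(M, K, device="meta")      # 2d operand
b2 = torch.empty(K, N, device="meta")      # 2d operand
a3 = torch.empty(G, M, K, device="meta")   # 3d (per-group) operand
b3 = torch.empty(G, K, N, device="meta")   # 3d (per-group) operand

print(_compute_grouped_gemm_output_size(a2, b2, offs))  # (G, M, N): offs splits K, output is 3d
print(_compute_grouped_gemm_output_size(a2, b3, offs))  # (M, N):    offs splits a's rows
print(_compute_grouped_gemm_output_size(a3, b2, offs))  # (M, N):    offs splits b's columns
print(_compute_grouped_gemm_output_size(a3, b3, offs))  # (G, M, N): regular bmm (offs unused)
```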