Update test after CUTLASS upgrade (#157903)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157903
Approved by: https://github.com/ngimel
Authored by Aleksandar Samardžić on 2025-07-09 17:12:00 +00:00; committed by PyTorch MergeBot
parent 8c5b070d1f
commit a3ec6d64b2
3 changed files with 5 additions and 10 deletions

@@ -48,12 +48,7 @@ __global__ void prepare_grouped_gemm_data(
   int32_t start = tid == 0 ? 0 : offs[tid - 1];
   delta = offs[tid] - start;
   if (K < 0) {
-    if (!a_row_major && b_row_major) {
-      CUDA_KERNEL_ASSERT(delta >= 0 && "expected offsets to be greater or equal 0\n");
-    } else {
-      // CUTLASS cannot handle delta=0 here.
-      CUDA_KERNEL_ASSERT(delta > 0 && "expected offsets to be greater than 0\n");
-    }
+    CUDA_KERNEL_ASSERT(delta >= 0 && "expected offsets to be greater or equal 0\n");
   }
   // TMA transfers require global memory tensor addresses to be
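
In the kernel, offs holds cumulative group offsets, so group tid covers [start, start + delta) with delta = offs[tid] - start; a repeated offset encodes an empty group (delta == 0). The upgraded CUTLASS handles empty groups for every layout combination, so the layout-dependent assert collapses into a single non-negativity check. A minimal sketch of that per-group arithmetic, as a hypothetical pure-Python helper (illustration only, not part of PyTorch):

def group_extents(offs):
    # Mirrors the per-thread computation in prepare_grouped_gemm_data:
    # thread tid derives its group's [start, start + delta) slice.
    extents = []
    for tid in range(len(offs)):
        start = 0 if tid == 0 else offs[tid - 1]
        delta = offs[tid] - start
        # After the CUTLASS upgrade, delta == 0 (an empty group) is legal
        # for all layouts; only negative deltas are rejected.
        assert delta >= 0, "expected offsets to be non-decreasing"
        extents.append((start, delta))
    return extents

# The repeated 16 yields an empty group at index 3, now accepted:
print(group_extents([0, 8, 16, 16, 27]))
# -> [(0, 0), (0, 8), (8, 8), (16, 0), (16, 11)]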

@@ -519,9 +519,9 @@ class TestMatmulCuda(TestCase):
         m_align = (m + align - 1) // align * align
         n_align = (n + align - 1) // align * align
         if not a_row_major and not b_row_major:
-            offs = torch.tensor([1, 3, 4, 6, 7], device=device, dtype=dtype_offset)
+            offs = torch.tensor([0, 1, 6, 6, 7], device=device, dtype=dtype_offset)
         else:
-            offs = torch.tensor([8, 16, 32, 37], device=device, dtype=dtype_offset)
+            offs = torch.tensor([0, 8, 16, 16, 27], device=device, dtype=dtype_offset)
         ngroups = offs.shape[0]
         k = offs[-1]
         k_align = (k + align - 1) // align * align
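
The new offsets deliberately start at 0 and repeat a value (6, 6 and 16, 16), so both lists now contain zero-size groups and exercise the empty-group path the kernel change enables. A usage sketch, assuming the private torch._grouped_mm(mat_a, mat_b, offs=...) entry point (signature as in the meta registration below), with offsets kept at multiples of 8 bf16 elements to satisfy the 16-byte TMA alignment mentioned in the kernel:

import torch

# 2D x 2D grouped GEMM: groups are k-slices of a and b, delimited by offs.
# Requires a CUDA device supported by the CUTLASS grouped kernels.
device = "cuda"
offs = torch.tensor([0, 8, 16, 16, 32], device=device, dtype=torch.int32)
k = int(offs[-1])
a = torch.randn(64, k, device=device, dtype=torch.bfloat16)
b = torch.randn(k, 64, device=device, dtype=torch.bfloat16)

# Group i computes a[:, offs[i-1]:offs[i]] @ b[offs[i-1]:offs[i], :];
# the repeated 16 makes that slice empty for one group.
out = torch._grouped_mm(a, b, offs=offs)
print(out.shape)  # one (64, 64) product per group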

@@ -7678,7 +7678,7 @@ def _meta_grouped_mm_common(

 @register_meta(aten._grouped_mm)
 @out_wrapper()
-def grouped_mm(
+def meta_grouped_mm(
     mat_a: Tensor,
     mat_b: Tensor,
     offs: Optional[Tensor] = None,
@@ -7697,7 +7697,7 @@ def grouped_mm(
     )


-@register_meta([aten._scaled_grouped_mm.default])
+@register_meta([aten._scaled_grouped_mm])
 def meta_scaled_grouped_mm(
     mat_a: torch.Tensor,
     mat_b: torch.Tensor,
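
These meta registrations let tracing stacks (FakeTensor, torch.compile) infer output shapes and dtypes without running the CUDA kernel. A sketch on the meta device, assuming meta_grouped_mm derives the group count from offs.shape[0], since no tensor data exists there:

import torch

# On the meta device the registered meta function runs instead of the
# CUDA kernel: tensors carry only shape/dtype, never data.
offs = torch.empty(5, device="meta", dtype=torch.int32)  # values irrelevant on meta
a = torch.empty(64, 32, device="meta", dtype=torch.bfloat16)
b = torch.empty(32, 64, device="meta", dtype=torch.bfloat16)
out = torch._grouped_mm(a, b, offs=offs)
print(out.shape)  # inferred purely from metadata, no kernel launch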