Update test after CUTLASS upgrade (#157903)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157903
Approved by: https://github.com/ngimel
This commit is contained in:
Aleksandar Samardžić 2025-07-09 17:12:00 +00:00 committed by PyTorch MergeBot
parent 8c5b070d1f
commit a3ec6d64b2
3 changed files with 5 additions and 10 deletions

View File

@ -48,12 +48,7 @@ __global__ void prepare_grouped_gemm_data(
int32_t start = tid == 0 ? 0 : offs[tid - 1];
delta = offs[tid] - start;
if (K < 0) {
if (!a_row_major && b_row_major) {
CUDA_KERNEL_ASSERT(delta >=0 && "expected ofsets to be greater or equal 0\n");
} else {
// CUTLASS cannot handle delta=0 here.
CUDA_KERNEL_ASSERT(delta >0 && "expected ofsets to be greater than 0\n");
}
CUDA_KERNEL_ASSERT(delta >=0 && "expected ofsets to be greater or equal 0\n");
}
// TMA transfers require global memory tensor addresses to be

View File

@ -519,9 +519,9 @@ class TestMatmulCuda(TestCase):
m_align = (m + align - 1) // align * align
n_align = (n + align - 1) // align * align
if not a_row_major and not b_row_major:
offs = torch.tensor([1, 3, 4, 6, 7], device=device, dtype=dtype_offset)
offs = torch.tensor([0, 1, 6, 6, 7], device=device, dtype=dtype_offset)
else:
offs = torch.tensor([8, 16, 32, 37], device=device, dtype=dtype_offset)
offs = torch.tensor([0, 8, 16, 16, 27], device=device, dtype=dtype_offset)
ngroups = offs.shape[0]
k = offs[-1]
k_align = (k + align - 1) // align * align

View File

@ -7678,7 +7678,7 @@ def _meta_grouped_mm_common(
@register_meta(aten._grouped_mm)
@out_wrapper()
def grouped_mm(
def meta_grouped_mm(
mat_a: Tensor,
mat_b: Tensor,
offs: Optional[Tensor] = None,
@ -7697,7 +7697,7 @@ def grouped_mm(
)
@register_meta([aten._scaled_grouped_mm.default])
@register_meta([aten._scaled_grouped_mm])
def meta_scaled_grouped_mm(
mat_a: torch.Tensor,
mat_b: torch.Tensor,