Update test after CUTLASS upgrade (#157903)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157903
Approved by: https://github.com/ngimel
This commit is contained in:
parent 8c5b070d1f
commit a3ec6d64b2
@@ -48,12 +48,7 @@ __global__ void prepare_grouped_gemm_data(
     int32_t start = tid == 0 ? 0 : offs[tid - 1];
     delta = offs[tid] - start;
     if (K < 0) {
-      if (!a_row_major && b_row_major) {
-        CUDA_KERNEL_ASSERT(delta >=0 && "expected ofsets to be greater or equal 0\n");
-      } else {
-        // CUTLASS cannot handle delta=0 here.
-        CUDA_KERNEL_ASSERT(delta >0 && "expected ofsets to be greater than 0\n");
-      }
+      CUDA_KERNEL_ASSERT(delta >=0 && "expected ofsets to be greater or equal 0\n");
     }

     // TMA transfers require global memory tensor addresses to be
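For context, here is a minimal Python sketch (not the kernel itself, and with illustrative offset values) of how the per-group extent delta is derived from the cumulative offs tensor. After the CUTLASS upgrade the kernel only requires delta >= 0, so zero-sized groups, i.e. repeated offsets, are accepted regardless of layout.

import torch

# Illustrative only: mirrors the start/delta arithmetic in the kernel hunk above.
# Repeated offsets (the two 6s) produce delta == 0, an empty group, which the
# relaxed assert now permits for every layout combination.
offs = torch.tensor([0, 1, 6, 6, 7], dtype=torch.int32)
start = torch.cat([offs.new_zeros(1), offs[:-1]])   # start[tid] = 0 or offs[tid - 1]
delta = offs - start                                # per-group size along the grouped dim
assert bool((delta >= 0).all()), "offsets must be non-decreasing"
print(delta.tolist())  # [0, 1, 5, 0, 1]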
@@ -519,9 +519,9 @@ class TestMatmulCuda(TestCase):
         m_align = (m + align - 1) // align * align
         n_align = (n + align - 1) // align * align
         if not a_row_major and not b_row_major:
-            offs = torch.tensor([1, 3, 4, 6, 7], device=device, dtype=dtype_offset)
+            offs = torch.tensor([0, 1, 6, 6, 7], device=device, dtype=dtype_offset)
         else:
-            offs = torch.tensor([8, 16, 32, 37], device=device, dtype=dtype_offset)
+            offs = torch.tensor([0, 8, 16, 16, 27], device=device, dtype=dtype_offset)
         ngroups = offs.shape[0]
         k = offs[-1]
         k_align = (k + align - 1) // align * align
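A short, hedged sketch of the arithmetic in this test hunk (the align value below is made up, not taken from the test): (x + align - 1) // align * align rounds x up to the next multiple of align, and the new offs values deliberately repeat an entry so that one group is empty.

# Round-up-to-multiple expression as written in the test body above.
align = 16
for x in (1, 16, 27, 37):
    print(x, (x + align - 1) // align * align)  # 1 -> 16, 16 -> 16, 27 -> 32, 37 -> 48

# The updated offsets contain a repeated value, so one group has size zero.
offs = [0, 8, 16, 16, 27]
sizes = [b - a for a, b in zip([0] + offs[:-1], offs)]
print(sizes)  # [0, 8, 8, 0, 11]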
@@ -7678,7 +7678,7 @@ def _meta_grouped_mm_common(

 @register_meta(aten._grouped_mm)
 @out_wrapper()
-def grouped_mm(
+def meta_grouped_mm(
     mat_a: Tensor,
     mat_b: Tensor,
     offs: Optional[Tensor] = None,
@@ -7697,7 +7697,7 @@ def grouped_mm(
     )


-@register_meta([aten._scaled_grouped_mm.default])
+@register_meta([aten._scaled_grouped_mm])
 def meta_scaled_grouped_mm(
     mat_a: torch.Tensor,
     mat_b: torch.Tensor,
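For readers unfamiliar with the file being edited: registrations in torch/_meta_registrations.py give operators a shape-only implementation on the "meta" device. The following is a hedged illustration using plain matmul rather than the grouped op itself; the grouped op's exact meta behaviour is defined by the registrations in the hunks above, not by this snippet.

import torch

# Meta tensors carry only shape/dtype/device metadata; no memory is allocated
# and no real kernel runs, which is how shape propagation is exercised.
a = torch.empty(8, 64, device="meta", dtype=torch.bfloat16)
b = torch.empty(64, 32, device="meta", dtype=torch.bfloat16)
out = a @ b
print(out.shape, out.device)  # torch.Size([8, 32]) meta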