Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
[TEST][ATen][CUDA] Skip row-wise scaled matrix multiplication tests on sm_120+ (#152814)
The float8 row-wise scaled matmuls are not supported on Blackwell yet. This PR adds skips to those tests to decrease the noise on `sm_120+` machines.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152814
Approved by: https://github.com/eqy, https://github.com/Skylion007
This commit is contained in:
parent 4b8b7c7fb9
commit 086e2c2399
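For context, the gating pattern this change relies on is a device-capability check that turns a test into an expected failure on newer GPUs rather than letting it fail as noise. A minimal, self-contained sketch of that idea using only the public `unittest` and `torch.cuda` APIs (`xfail_if_sm120_or_later` is an illustrative stand-in, not the helper added in this commit):

import unittest
import torch

def xfail_if_sm120_or_later(func):
    # Illustrative stand-in: mark the test as an expected failure when the
    # current device reports compute capability 12.0 or newer (Blackwell).
    if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (12, 0):
        return unittest.expectedFailure(func)
    return func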
@@ -24,6 +24,7 @@ from torch.testing._internal.common_cuda import (
     SM89OrLater,
     SM90OrLater,
     xfailIfSM100OrLater,
+    xfailIfSM120OrLater,
     _get_torch_cuda_version,
     PLATFORM_SUPPORTS_FP8,
     PLATFORM_SUPPORTS_MX_GEMM,
@@ -1012,8 +1013,9 @@ class TestFP8Matmul(TestCase):
         out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True)
         self.assertEqual(out_fp8, out_fp8_s)

+    @xfailIfSM120OrLater
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
-    @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89+ specific")
+    @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89-sm100 specific")
     @parametrize("use_fast_accum", [True, False])
     def test_float8_rowwise_scaling_sanity(self, device, use_fast_accum: bool) -> None:
         M, K, N = (1024, 512, 2048)
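The gated test exercises the row-wise scaled float8 path of `torch._scaled_mm`. As a rough sketch of what such a call looks like (assumed shapes, dtypes, and layout; this is not the test body, and it needs an sm89+ GPU with fp8 support):

import torch

M, K, N = 1024, 512, 2048
device = "cuda"

# float8 operands; the second operand is transposed so it is column-major.
x = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
y = torch.randn(N, K, device=device).to(torch.float8_e4m3fn).t()

# Row-wise scales: one float32 scale per row of x and per column of y.
scale_a = torch.ones(M, 1, device=device, dtype=torch.float32)
scale_b = torch.ones(1, N, device=device, dtype=torch.float32)

out = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b,
                       out_dtype=torch.bfloat16, use_fast_accum=True)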
@@ -1117,8 +1119,9 @@ class TestFP8Matmul(TestCase):
             out_dtype=torch.bfloat16,
         )

+    @xfailIfSM120OrLater
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
-    @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89+ specific")
+    @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89-sm100 specific")
     @parametrize("base_dtype", [torch.bfloat16])
     def test_scaled_mm_vs_emulated_row_wise(self, base_dtype):
         torch.manual_seed(42)
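The second gated test checks `torch._scaled_mm` against an emulated reference. A common way to emulate a row-wise scaled matmul is to dequantize the float8 operands into a wider dtype, apply the per-row and per-column scales, and run a plain matmul; a hedged sketch of that idea (function name and details are illustrative, not taken from the test file):

import torch

def emulated_rowwise_scaled_mm(x_fp8, y_fp8, scale_a, scale_b, out_dtype=torch.bfloat16):
    # Dequantize to float32, apply the (M, 1) and (1, N) scales by broadcasting,
    # then compute the reference product in full precision.
    x = x_fp8.to(torch.float32) * scale_a
    y = y_fp8.to(torch.float32) * scale_b
    return (x @ y).to(out_dtype)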
@@ -33,6 +33,7 @@ SM80OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_devic
 SM89OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9))
 SM90OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0))
 SM100OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (10, 0))
+SM120OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (12, 0))

 IS_THOR = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 10
                   and torch.cuda.get_device_capability()[1] > 0)
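The `SM*OrLater` flags are wrapped in `LazyVal` so the capability query runs on first use instead of at import time (e.g. importing the helpers on a CPU-only machine should not touch CUDA). A simplified sketch of that kind of lazy wrapper; PyTorch's actual `LazyVal` differs in details:

import torch

class LazyVal:
    # Simplified illustration: defer the thunk until the value is first
    # needed, then cache the result.
    def __init__(self, fn):
        self._fn = fn
        self._computed = False
        self._value = None

    def __bool__(self):
        if not self._computed:
            self._value = self._fn()
            self._computed = True
        return bool(self._value)

SM120OrLater = LazyVal(lambda: torch.cuda.is_available()
                       and torch.cuda.get_device_capability() >= (12, 0))

if SM120OrLater:  # the capability query happens here, not at import
    print("running on sm_120+ hardware")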
@@ -335,6 +336,9 @@ def xfailIfSM89(func):
 def xfailIfSM100OrLater(func):
     return func if not SM100OrLater else unittest.expectedFailure(func)

+def xfailIfSM120OrLater(func):
+    return func if not SM120OrLater else unittest.expectedFailure(func)
+
 def xfailIfDistributedNotSupported(func):
     return func if not (IS_MACOS or IS_JETSON) else unittest.expectedFailure(func)
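Usage then mirrors the existing `xfailIf*` helpers: stack the decorator on a test method so it runs normally below sm_120 and is wrapped in `unittest.expectedFailure` on sm_120+. A hypothetical example (class, test name, and body are illustrative):

import unittest
import torch
from torch.testing._internal.common_cuda import xfailIfSM120OrLater
from torch.testing._internal.common_utils import TestCase, run_tests

class TestRowwiseScaledMM(TestCase):
    @xfailIfSM120OrLater
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA required")
    def test_rowwise_path(self):
        ...  # test body elided; on sm_120+ a failure here is reported as expected

if __name__ == "__main__":
    run_tests()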