mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
test_matmul_cuda: Refine MX test skipping (#161009)
Replace return unittest.skip with raise unittest.SkipTest to ensure that the test suite correctly reports skipped tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/161009 Approved by: https://github.com/jeffdaily
This commit is contained in:
parent
a3a82e3da8
commit
543896fcf3
|
|
@ -1565,12 +1565,12 @@ class TestFP8Matmul(TestCase):
|
||||||
@parametrize("recipe", ["mxfp8", "mxfp4" if torch.version.hip else "nvfp4"])
|
@parametrize("recipe", ["mxfp8", "mxfp4" if torch.version.hip else "nvfp4"])
|
||||||
def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None:
|
def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None:
|
||||||
if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum:
|
if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum:
|
||||||
return unittest.skip("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping")
|
raise unittest.SkipTest("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping")
|
||||||
|
|
||||||
device = "cuda"
|
device = "cuda"
|
||||||
M, K, N = mkn
|
M, K, N = mkn
|
||||||
if (recipe == "nvfp4" or recipe == "mxfp4") and K % 32 != 0:
|
if (recipe == "nvfp4" or recipe == "mxfp4") and K % 32 != 0:
|
||||||
return unittest.skip("K must be divisible by 32 for nvfp4/mxfp4 cublas gemm, skipping")
|
raise unittest.SkipTest("K must be divisible by 32 for nvfp4/mxfp4 cublas gemm, skipping")
|
||||||
|
|
||||||
fp4_scaling_dtype = torch.float8_e8m0fnu if torch.version.hip else torch.float8_e4m3fn
|
fp4_scaling_dtype = torch.float8_e8m0fnu if torch.version.hip else torch.float8_e4m3fn
|
||||||
BLOCK_SIZE = 32 if torch.version.hip else (16 if recipe == "nvfp4" else 32)
|
BLOCK_SIZE = 32 if torch.version.hip else (16 if recipe == "nvfp4" else 32)
|
||||||
|
|
@ -1718,7 +1718,7 @@ class TestFP8Matmul(TestCase):
|
||||||
|
|
||||||
elif test_case_name == "data_random_scales_from_data":
|
elif test_case_name == "data_random_scales_from_data":
|
||||||
if not K % BLOCK_SIZE == 0:
|
if not K % BLOCK_SIZE == 0:
|
||||||
return unittest.skip(f"this test is only defined for K a multiple of {BLOCK_SIZE}, skipping")
|
raise unittest.SkipTest(f"this test is only defined for K a multiple of {BLOCK_SIZE}, skipping")
|
||||||
require_exact_match = False
|
require_exact_match = False
|
||||||
# random data, scales from data
|
# random data, scales from data
|
||||||
A_ref = torch.randn((M, K), device=device, dtype=torch.bfloat16) * 1000
|
A_ref = torch.randn((M, K), device=device, dtype=torch.bfloat16) * 1000
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user