[B200][MXFP8] Fix regex in test_blockwise_mxfp8_nvfp4_error_messages_recipe_mxfp8_cuda (#162180)
to unblock https://github.com/pytorch/pytorch/pull/159494

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162180
Approved by: https://github.com/Skylion007, https://github.com/drisspg, https://github.com/nWEIdia
parent 9499c8761c
commit c7e41071a0
@@ -1799,10 +1799,9 @@ class TestFP8Matmul(TestCase):
         # Test wrong scale tensor size for scale_a with correct dtype
         with self.assertRaisesRegex(
             RuntimeError,
-            re.escape(
-                f"For BlockWise scaling: Expected scale_a size to be {expected_a_size} "
-                f"but got {expected_a_size - 1}"
-            ),
+            f".*For Block[W,w]ise.*scaling.*scale_a should have {expected_a_size} "
+            f"elements.*"
+            ,
         ):
             incorrect_size_a = torch.ones(expected_a_size - 1, device=device, dtype=scale_dtype)
             correct_size_b = torch.ones(expected_b_size, device=device, dtype=scale_dtype)
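Why the pattern changed: the old expectation wrapped one exact wording in re.escape(), which backslash-escapes regex metacharacters, so the test could only match that single message. The replacement drops re.escape() and relies on .* wildcards plus the [W,w] character class (which matches 'W', 'w', or a literal comma), so any message that mentions blockwise scaling and says scale_a "should have {expected_a_size}" elements will satisfy assertRaisesRegex. A minimal sketch of the difference, using made-up message wordings rather than the exact strings torch._scaled_mm raises:

    import re

    expected_a_size = 128  # stand-in value; the real test derives it from the matmul shapes

    # Illustrative wordings only -- not the literal messages produced by torch._scaled_mm.
    old_wording = "For BlockWise scaling: Expected scale_a size to be 128 but got 127"
    new_wording = "For Blockwise scaling, scale_a should have 128 elements but got 127"

    # Old test pattern: an escaped, exact string. It only matches the wording it was written against.
    exact = re.escape(
        f"For BlockWise scaling: Expected scale_a size to be {expected_a_size} "
        f"but got {expected_a_size - 1}"
    )
    assert re.search(exact, old_wording) is not None
    assert re.search(exact, new_wording) is None  # fails once the message is reworded

    # New test pattern: wildcards and a character class tolerate the capitalization
    # difference and the "should have ... elements" wording.
    flexible = f".*For Block[W,w]ise.*scaling.*scale_a should have {expected_a_size} elements.*"
    assert re.search(flexible, new_wording) is not None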
@@ -1817,10 +1816,9 @@ class TestFP8Matmul(TestCase):
         # Test wrong scale tensor size for scale_b with correct dtype
         with self.assertRaisesRegex(
             RuntimeError,
-            re.escape(
-                f"For BlockWise scaling: Expected scale_b size to be {expected_b_size} "
-                f"but got {expected_b_size + 1}"
-            ),
+            f"For Block[W,w]ise.*scaling.*scale_b should have {expected_b_size} "
+            f"elements.*"
+            ,
         ):
             correct_size_a = torch.ones(expected_a_size, device=device, dtype=scale_dtype)
             incorrect_size_b = torch.ones(expected_b_size + 1, device=device, dtype=scale_dtype)
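The scale_b hunk is the same relaxation. One detail worth keeping in mind when reading these patterns: assertRaisesRegex applies its second argument with re.search against str(exception), so a partial pattern that occurs anywhere in the message is sufficient, and the explicit leading/trailing .* are harmless but not required. A toy, self-contained case (not part of the PyTorch suite; the message text is invented):

    import unittest


    class RegexMatchingDemo(unittest.TestCase):
        # Toy test: assertRaisesRegex uses re.search, so a partial pattern that
        # appears anywhere in the exception text is enough to pass.
        def test_partial_pattern_is_enough(self):
            with self.assertRaisesRegex(RuntimeError, r"scale_b should have \d+ elements"):
                raise RuntimeError(
                    "For Blockwise scaling, scale_b should have 64 elements but got 65"
                )


    if __name__ == "__main__":
        unittest.main()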
@@ -1835,9 +1833,8 @@ class TestFP8Matmul(TestCase):
         # Test non-contiguous scale tensors with correct dtype
         with self.assertRaisesRegex(
             RuntimeError,
-            re.escape(
-                "For BlockWise scaling: Both scale_a and scale_b must be contiguous"
-            ),
+            "For Block[W,w]ise.*scaling.*both should be contiguous"
+            ,
         ):
             non_contiguous_a = torch.ones(expected_a_size * 2, device=device, dtype=scale_dtype)[::2]
             contiguous_b = torch.ones(expected_b_size, device=device, dtype=scale_dtype)
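The contiguity case gets the same regex relaxation; the notable part of the setup is how the non-contiguous scale is built: allocating twice the needed elements and slicing with step 2 keeps the expected element count but yields a strided view, which PyTorch reports as non-contiguous, so the test expects the op to reject it. A quick standalone illustration (the size here is arbitrary; the real test computes expected_a_size from the matmul shapes):

    import torch

    expected_a_size = 8  # stand-in value

    # Slicing with step 2 keeps the element count but produces a strided view.
    non_contiguous_a = torch.ones(expected_a_size * 2)[::2]

    assert non_contiguous_a.numel() == expected_a_size
    assert non_contiguous_a.stride() == (2,)
    assert not non_contiguous_a.is_contiguous()

    # A dense copy would be contiguous again; the test instead expects the op
    # to raise rather than silently accept the strided scale tensor.
    assert non_contiguous_a.contiguous().is_contiguous()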