[Torch] Fix crash when comparing fp8 tensors that have more than 1 dimension (#153508)

Summary: `torch.nonzero` returns one index per dimension for each mismatched element, so we shouldn't expect the first mismatch's indices to be a single scalar.

Test Plan: CI

Differential Revision: D74539233

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153508
Approved by: https://github.com/exclamaforte
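For context on the summary: for an n-dimensional input, `torch.nonzero` returns a `(num_nonzero, n)` tensor, so each row carries one index per dimension. A minimal sketch (mine, not part of the commit) of why the old single-element assumption crashed for tensors with more than one dimension:

    import torch

    # A 2-D stand-in for the `matches` mask built inside make_tensor_mismatch_msg.
    matches = torch.tensor([[True, False], [True, True]])

    # torch.nonzero returns a (num_nonzero, ndim) tensor: one index per dimension
    # per hit, so the first row has two entries here, not one.
    first = torch.nonzero(~matches, as_tuple=False)[0]
    print(first)  # tensor([0, 1])

    # The old code called first.item(), which raises a RuntimeError for ndim > 1
    # because the row holds more than one element. The fix keeps the whole
    # per-dimension index instead:
    first_mismatch_idx = tuple(first.tolist())
    print(first_mismatch_idx)  # (0, 1)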
Pat Vignola 2025-05-15 08:41:46 +00:00 committed by PyTorch MergeBot
parent b297e01f4b
commit 6e107899da
2 changed files with 29 additions and 15 deletions


@@ -951,19 +951,33 @@ class TestAssertCloseErrorMessage(TestCase):
             torch.float8_e5m2fnuz,
             torch.float8_e8m0fnu,
         ]:
-            w = torch.tensor([3.14, 1.0], dtype=dtype)
-            x = torch.tensor([1.0, 3.14], dtype=dtype)
-            y = torch.tensor([3.14, 3.14], dtype=dtype)
-            z = torch.tensor([1.0, 3.14], dtype=dtype)
-            for fn in assert_close_with_inputs(x, y):
-                with self.assertRaisesRegex(AssertionError, re.escape("The first mismatched element is at index 0")):
-                    fn()
-            for fn in assert_close_with_inputs(w, y):
-                with self.assertRaisesRegex(AssertionError, re.escape("The first mismatched element is at index 1")):
-                    fn()
-            for fn in assert_close_with_inputs(x, z):
-                fn()
+            w_vector = torch.tensor([3.14, 1.0], dtype=dtype)
+            x_vector = torch.tensor([1.0, 3.14], dtype=dtype)
+            y_vector = torch.tensor([3.14, 3.14], dtype=dtype)
+            z_vector = torch.tensor([1.0, 3.14], dtype=dtype)
+
+            for additional_dims in range(4):
+                new_shape = list(w_vector.shape) + ([1] * additional_dims)
+                w_tensor = w_vector.reshape(new_shape)
+                x_tensor = x_vector.reshape(new_shape)
+                y_tensor = y_vector.reshape(new_shape)
+                z_tensor = z_vector.reshape(new_shape)
+
+                for fn in assert_close_with_inputs(x_tensor, y_tensor):
+                    expected_shape = (0,) + (0,) * (additional_dims)
+                    with self.assertRaisesRegex(
+                        AssertionError, re.escape(f"The first mismatched element is at index {expected_shape}")
+                    ):
+                        fn()
+
+                for fn in assert_close_with_inputs(w_tensor, y_tensor):
+                    expected_shape = (1,) + (0,) * (additional_dims)
+                    with self.assertRaisesRegex(
+                        AssertionError, re.escape(f"The first mismatched element is at index {expected_shape}")
+                    ):
+                        fn()
+
+                for fn in assert_close_with_inputs(x_tensor, z_tensor):
+                    fn()
 
     def test_abs_diff_scalar(self):
         actual = 3
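The reshaping loop is the interesting part of the updated test: the same two-element fp8 vectors are checked with up to three trailing singleton dimensions. A hypothetical standalone repro (not part of the PR, assuming a build that includes this fix and supports the `torch.float8_e5m2fnuz` dtype) of what it asserts:

    import torch

    # 3-D fp8 tensors that differ only in the element at index (1, 0, 0).
    w = torch.tensor([3.14, 1.0], dtype=torch.float8_e5m2fnuz).reshape(2, 1, 1)
    y = torch.tensor([3.14, 3.14], dtype=torch.float8_e5m2fnuz).reshape(2, 1, 1)

    try:
        torch.testing.assert_close(w, y)
    except AssertionError as e:
        # Before the fix, building this message crashed with a RuntimeError for
        # >1-D fp8 inputs; with it, the error reports a per-dimension index
        # such as (1, 0, 0).
        print(e)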


@@ -133,7 +133,7 @@ def _make_bitwise_mismatch_msg(
     default_identifier: str,
     identifier: Optional[Union[str, Callable[[str], str]]] = None,
     extra: Optional[str] = None,
-    first_mismatch_idx: Optional[int] = None,
+    first_mismatch_idx: Optional[tuple[int]] = None,
 ):
     """Makes a mismatch error message for bitwise values.
@@ -143,7 +143,7 @@ def _make_bitwise_mismatch_msg(
             ``default_identifier``. Can be passed as callable in which case it will be called with
             ``default_identifier`` to create the description at runtime.
         extra (Optional[str]): Extra information to be placed after the message header and the mismatch statistics.
-        first_mismatch_idx (Optional[int]): the index of the first mismatch.
+        first_mismatch_idx (Optional[tuple[int]]): the index of the first mismatch, for each dimension.
     """
     if identifier is None:
         identifier = default_identifier
@@ -296,12 +296,12 @@ def make_tensor_mismatch_msg(
     )
     if actual.dtype.is_floating_point and actual.dtype.itemsize == 1:
         # skip checking for max_abs_diff and max_rel_diff for float8-like values
-        first_mismatch_idx = torch.nonzero(~matches, as_tuple=False)[0].item()
+        first_mismatch_idx = tuple(torch.nonzero(~matches, as_tuple=False)[0].tolist())
         return _make_bitwise_mismatch_msg(
             default_identifier="Tensor-likes",
             identifier=identifier,
             extra=extra,
-            first_mismatch_idx=int(first_mismatch_idx),
+            first_mismatch_idx=first_mismatch_idx,
         )
     actual_flat = actual.flatten()
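A small illustration (mine, not from the PR) of why passing the tuple straight through works: Python renders an int and a tuple differently when interpolated into the message text that the tests match against, so the per-dimension index shows up verbatim.

    # Old path (1-D only): an int, rendered as e.g. "index 1".
    # New path: a tuple of per-dimension indices, rendered as e.g. "index (1, 0, 0)",
    # which is exactly what the updated test regexes expect.
    old_style = 1
    new_style = (1, 0, 0)
    print(f"The first mismatched element is at index {old_style}")  # ... at index 1
    print(f"The first mismatched element is at index {new_style}")  # ... at index (1, 0, 0)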