mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Torch] Fix crash when comparing fp8 tensors that have more than 1 dimension (#153508)
Summary: `torch.nonzero` returns as many items as the number of dimensions, so we shouldn't expect a single element for the indices. Test Plan: CI Differential Revision: D74539233 Pull Request resolved: https://github.com/pytorch/pytorch/pull/153508 Approved by: https://github.com/exclamaforte
This commit is contained in:
parent
b297e01f4b
commit
6e107899da
|
|
@ -951,19 +951,33 @@ class TestAssertCloseErrorMessage(TestCase):
|
||||||
torch.float8_e5m2fnuz,
|
torch.float8_e5m2fnuz,
|
||||||
torch.float8_e8m0fnu,
|
torch.float8_e8m0fnu,
|
||||||
]:
|
]:
|
||||||
w = torch.tensor([3.14, 1.0], dtype=dtype)
|
w_vector = torch.tensor([3.14, 1.0], dtype=dtype)
|
||||||
x = torch.tensor([1.0, 3.14], dtype=dtype)
|
x_vector = torch.tensor([1.0, 3.14], dtype=dtype)
|
||||||
y = torch.tensor([3.14, 3.14], dtype=dtype)
|
y_vector = torch.tensor([3.14, 3.14], dtype=dtype)
|
||||||
z = torch.tensor([1.0, 3.14], dtype=dtype)
|
z_vector = torch.tensor([1.0, 3.14], dtype=dtype)
|
||||||
for fn in assert_close_with_inputs(x, y):
|
|
||||||
with self.assertRaisesRegex(AssertionError, re.escape("The first mismatched element is at index 0")):
|
|
||||||
fn()
|
|
||||||
|
|
||||||
for fn in assert_close_with_inputs(w, y):
|
for additional_dims in range(4):
|
||||||
with self.assertRaisesRegex(AssertionError, re.escape("The first mismatched element is at index 1")):
|
new_shape = list(w_vector.shape) + ([1] * additional_dims)
|
||||||
|
w_tensor = w_vector.reshape(new_shape)
|
||||||
|
x_tensor = x_vector.reshape(new_shape)
|
||||||
|
y_tensor = y_vector.reshape(new_shape)
|
||||||
|
z_tensor = z_vector.reshape(new_shape)
|
||||||
|
|
||||||
|
for fn in assert_close_with_inputs(x_tensor, y_tensor):
|
||||||
|
expected_shape = (0,) + (0,) * (additional_dims)
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
AssertionError, re.escape(f"The first mismatched element is at index {expected_shape}")
|
||||||
|
):
|
||||||
|
fn()
|
||||||
|
|
||||||
|
for fn in assert_close_with_inputs(w_tensor, y_tensor):
|
||||||
|
expected_shape = (1,) + (0,) * (additional_dims)
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
AssertionError, re.escape(f"The first mismatched element is at index {expected_shape}")
|
||||||
|
):
|
||||||
|
fn()
|
||||||
|
for fn in assert_close_with_inputs(x_tensor, z_tensor):
|
||||||
fn()
|
fn()
|
||||||
for fn in assert_close_with_inputs(x, z):
|
|
||||||
fn()
|
|
||||||
|
|
||||||
def test_abs_diff_scalar(self):
|
def test_abs_diff_scalar(self):
|
||||||
actual = 3
|
actual = 3
|
||||||
|
|
|
||||||
|
|
@ -133,7 +133,7 @@ def _make_bitwise_mismatch_msg(
|
||||||
default_identifier: str,
|
default_identifier: str,
|
||||||
identifier: Optional[Union[str, Callable[[str], str]]] = None,
|
identifier: Optional[Union[str, Callable[[str], str]]] = None,
|
||||||
extra: Optional[str] = None,
|
extra: Optional[str] = None,
|
||||||
first_mismatch_idx: Optional[int] = None,
|
first_mismatch_idx: Optional[tuple[int]] = None,
|
||||||
):
|
):
|
||||||
"""Makes a mismatch error message for bitwise values.
|
"""Makes a mismatch error message for bitwise values.
|
||||||
|
|
||||||
|
|
@ -143,7 +143,7 @@ def _make_bitwise_mismatch_msg(
|
||||||
``default_identifier``. Can be passed as callable in which case it will be called with
|
``default_identifier``. Can be passed as callable in which case it will be called with
|
||||||
``default_identifier`` to create the description at runtime.
|
``default_identifier`` to create the description at runtime.
|
||||||
extra (Optional[str]): Extra information to be placed after the message header and the mismatch statistics.
|
extra (Optional[str]): Extra information to be placed after the message header and the mismatch statistics.
|
||||||
first_mismatch_idx (Optional[int]): the index of the first mismatch.
|
first_mismatch_idx (Optional[tuple[int]]): the index of the first mismatch, for each dimension.
|
||||||
"""
|
"""
|
||||||
if identifier is None:
|
if identifier is None:
|
||||||
identifier = default_identifier
|
identifier = default_identifier
|
||||||
|
|
@ -296,12 +296,12 @@ def make_tensor_mismatch_msg(
|
||||||
)
|
)
|
||||||
if actual.dtype.is_floating_point and actual.dtype.itemsize == 1:
|
if actual.dtype.is_floating_point and actual.dtype.itemsize == 1:
|
||||||
# skip checking for max_abs_diff and max_rel_diff for float8-like values
|
# skip checking for max_abs_diff and max_rel_diff for float8-like values
|
||||||
first_mismatch_idx = torch.nonzero(~matches, as_tuple=False)[0].item()
|
first_mismatch_idx = tuple(torch.nonzero(~matches, as_tuple=False)[0].tolist())
|
||||||
return _make_bitwise_mismatch_msg(
|
return _make_bitwise_mismatch_msg(
|
||||||
default_identifier="Tensor-likes",
|
default_identifier="Tensor-likes",
|
||||||
identifier=identifier,
|
identifier=identifier,
|
||||||
extra=extra,
|
extra=extra,
|
||||||
first_mismatch_idx=int(first_mismatch_idx),
|
first_mismatch_idx=first_mismatch_idx,
|
||||||
)
|
)
|
||||||
|
|
||||||
actual_flat = actual.flatten()
|
actual_flat = actual.flatten()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user