mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
fix TensorLikePair origination (#70304)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/70304 Without this patch `TensorLikePair` will try to instantiate everything although it should only do so for tensor-likes. This is problematic if it is used before a different pair that would be able to handle the inputs but never gets to do so, because `TensorLikePair` bails out before. ```python from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair assert_equal("a", "a", pair_types=(TensorLikePair, ObjectPair)) ``` ``` ValueError: Constructing a tensor from <class 'str'> failed with new(): invalid data type 'str'. ``` Test Plan: Imported from OSS Reviewed By: ngimel Differential Revision: D33542995 Pulled By: mruberry fbshipit-source-id: 77a5cc0abad44356c3ec64c7ec46e84d166ab2dd
This commit is contained in:
parent
49a5b33a74
commit
928ca95ff0
|
|
@ -745,7 +745,7 @@ class TestAssertClose(TestCase):
|
|||
expected = "0"
|
||||
|
||||
for fn in assert_close_with_inputs(actual, expected):
|
||||
with self.assertRaisesRegex(ValueError, str(type(actual))):
|
||||
with self.assertRaisesRegex(TypeError, str(type(actual))):
|
||||
fn()
|
||||
|
||||
def test_mismatching_shape(self):
|
||||
|
|
|
|||
|
|
@ -570,23 +570,19 @@ class TensorLikePair(Pair):
|
|||
if not allow_subclasses and type(actual) is not type(expected):
|
||||
raise UnsupportedInputs()
|
||||
|
||||
actual, expected = [self._to_tensor(input, id=id) for input in (actual, expected)]
|
||||
actual, expected = [self._to_tensor(input) for input in (actual, expected)]
|
||||
for tensor in (actual, expected):
|
||||
self._check_supported(tensor, id=id)
|
||||
return actual, expected
|
||||
|
||||
def _to_tensor(self, tensor_like: Any, *, id: Tuple[Any, ...]) -> torch.Tensor:
|
||||
def _to_tensor(self, tensor_like: Any) -> torch.Tensor:
|
||||
if isinstance(tensor_like, torch.Tensor):
|
||||
return tensor_like
|
||||
|
||||
try:
|
||||
return torch.as_tensor(tensor_like)
|
||||
except Exception as error:
|
||||
raise ErrorMeta(
|
||||
ValueError,
|
||||
f"Constructing a tensor from {type(tensor_like)} failed with \n{error}.",
|
||||
id=id,
|
||||
) from error
|
||||
except Exception:
|
||||
raise UnsupportedInputs()
|
||||
|
||||
def _check_supported(self, tensor: torch.Tensor, *, id: Tuple[Any, ...]) -> None:
|
||||
if tensor.layout not in {torch.strided, torch.sparse_coo, torch.sparse_csr}: # type: ignore[attr-defined]
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user