fix TensorLikePair origination (#70304)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70304

Without this patch `TensorLikePair` will try to instantiate everything although it should only do so for tensor-likes. This is problematic if it is used before a different pair that would be able to handle the inputs but never gets to do so, because `TensorLikePair` bails out before.

```python
from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair

assert_equal("a", "a", pair_types=(TensorLikePair, ObjectPair))
```

```
ValueError: Constructing a tensor from <class 'str'> failed with
new(): invalid data type 'str'.
```

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D33542995

Pulled By: mruberry

fbshipit-source-id: 77a5cc0abad44356c3ec64c7ec46e84d166ab2dd
This commit is contained in:
Philip Meier 2022-01-12 06:40:45 -08:00 committed by Facebook GitHub Bot
parent 49a5b33a74
commit 928ca95ff0
2 changed files with 5 additions and 9 deletions

View File

@ -745,7 +745,7 @@ class TestAssertClose(TestCase):
expected = "0"
for fn in assert_close_with_inputs(actual, expected):
with self.assertRaisesRegex(ValueError, str(type(actual))):
with self.assertRaisesRegex(TypeError, str(type(actual))):
fn()
def test_mismatching_shape(self):

View File

@ -570,23 +570,19 @@ class TensorLikePair(Pair):
if not allow_subclasses and type(actual) is not type(expected):
raise UnsupportedInputs()
actual, expected = [self._to_tensor(input, id=id) for input in (actual, expected)]
actual, expected = [self._to_tensor(input) for input in (actual, expected)]
for tensor in (actual, expected):
self._check_supported(tensor, id=id)
return actual, expected
def _to_tensor(self, tensor_like: Any, *, id: Tuple[Any, ...]) -> torch.Tensor:
def _to_tensor(self, tensor_like: Any) -> torch.Tensor:
if isinstance(tensor_like, torch.Tensor):
return tensor_like
try:
return torch.as_tensor(tensor_like)
except Exception as error:
raise ErrorMeta(
ValueError,
f"Constructing a tensor from {type(tensor_like)} failed with \n{error}.",
id=id,
) from error
except Exception:
raise UnsupportedInputs()
def _check_supported(self, tensor: torch.Tensor, *, id: Tuple[Any, ...]) -> None:
if tensor.layout not in {torch.strided, torch.sparse_coo, torch.sparse_csr}: # type: ignore[attr-defined]