Early return in _recursive_build if obj is a Tensor (#125639)

Fix issue #125551

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125639
Approved by: https://github.com/ezyang
This commit is contained in:
Guilherme Leobas 2024-05-16 17:08:51 -03:00 committed by PyTorch MergeBot
parent 7e166e8057
commit 402170b22f
6 changed files with 45 additions and 3 deletions

View File

@ -394,6 +394,36 @@ class UnspecTests(torch._dynamo.test_case.TestCase):
self.assertEqual(f3(r), optimize(f3)(r))
self.assertEqual(f4(r), optimize(f4)(r))
def test_to_tensor(self):
def f1():
a = np.random.uniform(low=-1, high=1, size=(20, 1))
return torch.tensor([a, a, a, a], dtype=torch.float64, device="cpu")
def f2():
a = torch.tensor([[[123]]])
return torch.tensor([a, a])
def f3():
a = torch.tensor(123)
return torch.tensor([a, a])
def f4():
a = torch.tensor(123)
b = torch.tensor([[[456]]])
return torch.tensor([a, b])
def f5():
a = np.array([1, 2])
return torch.tensor([a, a])
optimize = torch.compile(backend="aot_eager", fullgraph=True)
self.assertEqual(f1().shape, optimize(f1)().shape)
self.assertEqual(f2(), optimize(f2)())
self.assertEqual(f3(), optimize(f3)())
self.assertEqual(f4(), optimize(f4)())
self.assertEqual(f5(), optimize(f5)())
def test_sym_int_conversion(self):
def f(x):
y = x.size(0)

View File

@ -3830,7 +3830,7 @@ def _check_stack_inputs(tensors: TensorSequenceType) -> None:
entry_shape = tensors[0].shape
for i in range(1, len(tensors)):
assert tensors[i].shape == entry_shape, (
f"stack expects each tensor to be equal size, but got {entry_shape} at entry 0"
f"stack expects each tensor to be equal size, but got {entry_shape} at entry 0 "
f"and {tensors[i].shape} at entry {i}"
)
@ -6358,12 +6358,24 @@ def _infer_scalar_type(obj):
# Analogous to recursive_store
# xref: recursive_store in torch/csrc/utils/tensor_new.cpp
def _recursive_build(scalarType: torch.dtype, obj: TensorOrNumberLikeType):
if isinstance(obj, Tensor) and obj.ndim <= 1:
def _recursive_build(
scalarType: torch.dtype, obj: Union[TensorOrNumberLikeType, TensorSequenceType]
):
if isinstance(obj, Tensor) and obj.numel() == 1:
return obj.detach().to(dtype=scalarType, device="cpu", copy=True).view(())
elif isinstance(obj, Tensor):
# It is invalid to call ".tensor([...])" with a non-scalar tensor in eager mode
# >>> torch.tensor([torch.randn(2)])
# ValueError: only one element tensors can be converted to Python scalars
#
# But it is possible with a NumPy array
# >>> torch.tensor([np.random.uniform(size=(2,))]).shape
# torch.Size([1, 2])
return obj.detach().to(dtype=scalarType, device="cpu", copy=True)
elif isinstance(obj, Number):
return torch.scalar_tensor(obj, dtype=scalarType)
# seq can be a list of tensors
seq = obj
return torch.stack([_recursive_build(scalarType, item) for item in seq])