Early return in _recursive_build if obj is a Tensor (#125639)
Fix issue #125551

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125639
Approved by: https://github.com/ezyang
This commit is contained in:
parent 7e166e8057
commit 402170b22f
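As context for the hunks below: a minimal sketch of the pattern this change makes work under torch.compile. The exact reproducer in issue #125551 may differ; this is adapted from the new test added in this commit.

import torch

@torch.compile(backend="aot_eager", fullgraph=True)
def f():
    a = torch.tensor(123)
    return torch.tensor([a, a])  # a Python list of tensors passed to torch.tensor

print(f())  # should now match eager mode: tensor([123, 123])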
@@ -394,6 +394,36 @@ class UnspecTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(f3(r), optimize(f3)(r))
         self.assertEqual(f4(r), optimize(f4)(r))
 
+    def test_to_tensor(self):
+        def f1():
+            a = np.random.uniform(low=-1, high=1, size=(20, 1))
+            return torch.tensor([a, a, a, a], dtype=torch.float64, device="cpu")
+
+        def f2():
+            a = torch.tensor([[[123]]])
+            return torch.tensor([a, a])
+
+        def f3():
+            a = torch.tensor(123)
+            return torch.tensor([a, a])
+
+        def f4():
+            a = torch.tensor(123)
+            b = torch.tensor([[[456]]])
+            return torch.tensor([a, b])
+
+        def f5():
+            a = np.array([1, 2])
+            return torch.tensor([a, a])
+
+        optimize = torch.compile(backend="aot_eager", fullgraph=True)
+
+        self.assertEqual(f1().shape, optimize(f1)().shape)
+        self.assertEqual(f2(), optimize(f2)())
+        self.assertEqual(f3(), optimize(f3)())
+        self.assertEqual(f4(), optimize(f4)())
+        self.assertEqual(f5(), optimize(f5)())
+
     def test_sym_int_conversion(self):
         def f(x):
             y = x.size(0)
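A note on f1 above: np.random.uniform produces different values on each call, so the eager and compiled runs cannot be compared elementwise, and the test checks only the shape. Standalone, the construction behaves like this:

import numpy as np
import torch

a = np.random.uniform(low=-1, high=1, size=(20, 1))
t = torch.tensor([a, a, a, a], dtype=torch.float64)
print(t.shape)  # torch.Size([4, 20, 1]): four (20, 1) arrays stacked along a new dim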
@@ -3830,7 +3830,7 @@ def _check_stack_inputs(tensors: TensorSequenceType) -> None:
     entry_shape = tensors[0].shape
     for i in range(1, len(tensors)):
         assert tensors[i].shape == entry_shape, (
-            f"stack expects each tensor to be equal size, but got {entry_shape} at entry 0"
+            f"stack expects each tensor to be equal size, but got {entry_shape} at entry 0 "
             f"and {tensors[i].shape} at entry {i}"
         )
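The one-character fix above adds a trailing space because adjacent f-string literals are concatenated verbatim; without it the assertion message read "...at entry 0and ...". A quick illustration, where entry_shape, other_shape, and i are stand-ins for the real values:

entry_shape, other_shape, i = (2, 3), (4, 5), 1
msg = (
    f"stack expects each tensor to be equal size, but got {entry_shape} at entry 0 "
    f"and {other_shape} at entry {i}"
)
print(msg)
# stack expects each tensor to be equal size, but got (2, 3) at entry 0 and (4, 5) at entry 1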
@@ -6358,12 +6358,24 @@ def _infer_scalar_type(obj):
 
 # Analogous to recursive_store
 # xref: recursive_store in torch/csrc/utils/tensor_new.cpp
-def _recursive_build(scalarType: torch.dtype, obj: TensorOrNumberLikeType):
-    if isinstance(obj, Tensor) and obj.ndim <= 1:
-        obj = obj.item()
-    if isinstance(obj, Number):
+def _recursive_build(
+    scalarType: torch.dtype, obj: Union[TensorOrNumberLikeType, TensorSequenceType]
+):
+    if isinstance(obj, Tensor) and obj.numel() == 1:
+        return obj.detach().to(dtype=scalarType, device="cpu", copy=True).view(())
+    elif isinstance(obj, Tensor):
+        # It is invalid to call ".tensor([...])" with a non-scalar tensor in eager mode
+        # >>> torch.tensor([torch.randn(2)])
+        # ValueError: only one element tensors can be converted to Python scalars
+        #
+        # But it is possible with a NumPy array
+        # >>> torch.tensor([np.random.uniform(size=(2,))]).shape
+        # torch.Size([1, 2])
+        return obj.detach().to(dtype=scalarType, device="cpu", copy=True)
+    elif isinstance(obj, Number):
         return torch.scalar_tensor(obj, dtype=scalarType)
 
     # seq can be a list of tensors
     seq = obj
     return torch.stack([_recursive_build(scalarType, item) for item in seq])
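The behavioral points the new branches encode, shown in eager mode (which this ref implementation is meant to match):

import numpy as np
import torch

a = torch.tensor(123)        # numel() == 1, ndim == 0
b = torch.tensor([[[456]]])  # numel() == 1 but ndim == 3, which the old `ndim <= 1` check missed
print(torch.tensor([a, b]))  # tensor([123, 456]): any one-element tensor acts as a scalar

# A multi-element tensor is rejected in eager mode:
#   torch.tensor([torch.randn(2)])  ->  ValueError
# but an equivalent NumPy array is accepted and stacked into the result:
print(torch.tensor([np.random.uniform(size=(2,))]).shape)  # torch.Size([1, 2])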