[dynamo bug burndown] update tensor creation to support sequences of tensors (#120872)

Fixes https://github.com/pytorch/pytorch/issues/120645

`_internal_new_from_data` calls `_recursive_build`, but we run into errors such as the following cases.
```
Failed running call_function <function tensor at 0xDEADBEEF>:
scalar_tensor(): argument (position 1) must be Number, not FakeTensor

# e.g. cases
1. [FakeTensor(..., size=(20, 1), dtype=torch.float64), ..., FakeTensor(..., size=(20, 1), dtype=torch.float64)]
- Here, we call `_recursive_build(sizes=[4] ...)`, which hits the base case `if dim == ndim:` at the 2nd level of recursion.
- So, we try to return `scalar_tensor(FakeTensor)`
2. [[FakeTensor(..., size=(1,), dtype=torch.int64), FakeTensor(..., size=(), dtype=torch.int64)]]

# side note: when can size = ()? Probably from scalar_tensor.
>>> torch.scalar_tensor(1).shape
torch.Size([])
```
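
For illustration, a minimal sketch of the failing pattern (hypothetical shapes loosely following case 1 above; not a verbatim test from this PR):

```python
import torch

@torch.compile(backend="eager")
def build(xs):
    # Under tracing, each element of xs is a FakeTensor, so the old
    # _recursive_build base case called scalar_tensor(FakeTensor) and failed.
    return torch.tensor(xs)

xs = [torch.randn(20, 1, dtype=torch.float64) for _ in range(4)]
out = build(xs)  # previously: scalar_tensor(): argument ... must be Number, not FakeTensor
print(out.shape)  # torch.Size([4, 20, 1])
```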

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120872
Approved by: https://github.com/ezyang
Colin Peppler 2024-03-01 14:50:11 -08:00 committed by PyTorch MergeBot
parent a3b81666b1
commit 06fe6ed82b
8 changed files with 8 additions and 15 deletions


```diff
@@ -6332,21 +6332,15 @@ def _infer_scalar_type(obj):
 # Analogous to recursive_store
 # xref: recursive_store in torch/csrc/utils/tensor_new.cpp
-def _recursive_build(sizes, dim, scalarType, obj):
-    ndim = len(sizes)
-    assert dim <= ndim
-    if dim == ndim:
+def _recursive_build(scalarType: torch.dtype, obj: TensorOrNumberLikeType):
+    if isinstance(obj, Tensor) and obj.ndim <= 1:
+        obj = obj.item()
+        # fall through into next case
+    if isinstance(obj, Number):
         return torch.scalar_tensor(obj, dtype=scalarType)
-    n = sizes[dim]
     seq = obj
-    seq_size = len(seq)
-    if seq_size != n:
-        raise ValueError(
-            f"expected sequence of length {n} at dim {dim} (got {seq_size})"
-        )
-    return torch.stack(
-        [_recursive_build(sizes, dim + 1, scalarType, item) for item in seq]
-    )
+    return torch.stack([_recursive_build(scalarType, item) for item in seq])


 # xref: internal_new_from_data in torch/csrc/utils/tensor_new.cpp
```
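
To see the new control flow in isolation, here is a self-contained restatement of the rewritten function (a sketch only; in the real code `Tensor`, `Number`, and `TensorOrNumberLikeType` come from the surrounding module):

```python
import torch
from numbers import Number

def _recursive_build_sketch(scalar_type: torch.dtype, obj):
    # A tensor with ndim <= 1 is collapsed to a Python scalar via .item()
    # (this assumes a single element) and falls through to the Number case.
    if isinstance(obj, torch.Tensor) and obj.ndim <= 1:
        obj = obj.item()
    # Base case: a plain number becomes a 0-dim tensor.
    if isinstance(obj, Number):
        return torch.scalar_tensor(obj, dtype=scalar_type)
    # Recursive case: treat obj as a sequence and stack the results,
    # so sizes no longer need to be computed up front.
    return torch.stack([_recursive_build_sketch(scalar_type, x) for x in obj])

# Mixed nesting of numbers and scalar tensors is handled uniformly:
out = _recursive_build_sketch(
    torch.float64, [[torch.tensor(1.0), 2.0], [3.0, torch.scalar_tensor(4)]]
)
print(out)  # tensor([[1., 2.], [3., 4.]], dtype=torch.float64)
```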
```diff
@@ -6383,7 +6377,6 @@ def _internal_new_from_data(
     # TODO: test for numpy input with PyArray_Check
     device = device_opt if device_opt is not None else options["device"]
-    sizes = _compute_sizes(data, scalar_type)
     inferred_scalar_type = _infer_scalar_type(data) if type_inference else scalar_type

     # NB: Don't need to avoid tracing, as we aren't going to do any manual
@@ -6398,7 +6391,7 @@ def _internal_new_from_data(
         # of a freshly allocated CPU tensor. Here, we're going to do an
         # alternate, heinously slow implementation: turn each individual
         # scalar into a tensor, and then repeatedly cat them together
-        tensor = _recursive_build(sizes, 0, inferred_scalar_type, data)
+        tensor = _recursive_build(inferred_scalar_type, data)

     tensor = tensor.to(device, inferred_scalar_type, non_blocking=False, copy=False)
```
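
One behavioral note (my reading of the diff, not stated in the PR): with the explicit `sizes` check gone, ragged input now surfaces as a `torch.stack` size mismatch rather than the old `expected sequence of length ...` ValueError:

```python
import torch

# Hypothetical ragged nest; the old code raised
# "expected sequence of length 2 at dim 1 (got 3)" before stacking.
try:
    torch.stack([torch.tensor([1.0, 2.0]), torch.tensor([1.0, 2.0, 3.0])])
except RuntimeError as e:
    print(e)  # stack expects each tensor to be equal size, ...
```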