Don't call item() into torch.scalar_tensor uselessly (#125373)

Fixes https://github.com/pytorch/pytorch/issues/125368

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125373
Approved by: https://github.com/Skylion007
This commit is contained in:
Edward Z. Yang 2024-05-03 21:44:13 -07:00 committed by PyTorch MergeBot
parent 5ef50d75f8
commit 2b4fe183db
3 changed files with 15 additions and 3 deletions

View File

@ -4642,6 +4642,19 @@ def fn():
res2 = opt_fn(x)
self.assertEqual(res, res2)
@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
def test_tensor_ctor_list_of_tensor(self):
def fn(x):
return torch.tensor([x], dtype=torch.int64)
x = torch.tensor(20)
res = fn(x)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(fn)
res2 = opt_fn(x)
self.assertEqual(res, res2)
self.assertEqual(cnts.frame_count, 1)
def test_tensor_types(self):
def fn(dtype, tensor_type):
x = torch.empty(4, dtype=dtype)

View File

@ -6359,9 +6359,8 @@ def _infer_scalar_type(obj):
# xref: recursive_store in torch/csrc/utils/tensor_new.cpp
def _recursive_build(scalarType: torch.dtype, obj: TensorOrNumberLikeType):
if isinstance(obj, Tensor) and obj.ndim <= 1:
obj = obj.item()
# fall through into next case
if isinstance(obj, Number):
return obj.detach().to(dtype=scalarType, device="cpu", copy=True).view(())
elif isinstance(obj, Number):
return torch.scalar_tensor(obj, dtype=scalarType)
seq = obj