mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Don't call item() into torch.scalar_tensor uselessly (#125373)
Fixes https://github.com/pytorch/pytorch/issues/125368 Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/125373 Approved by: https://github.com/Skylion007
This commit is contained in:
parent
5ef50d75f8
commit
2b4fe183db
|
|
@ -4642,6 +4642,19 @@ def fn():
|
|||
res2 = opt_fn(x)
|
||||
self.assertEqual(res, res2)
|
||||
|
||||
@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
|
||||
def test_tensor_ctor_list_of_tensor(self):
|
||||
def fn(x):
|
||||
return torch.tensor([x], dtype=torch.int64)
|
||||
|
||||
x = torch.tensor(20)
|
||||
res = fn(x)
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
||||
res2 = opt_fn(x)
|
||||
self.assertEqual(res, res2)
|
||||
self.assertEqual(cnts.frame_count, 1)
|
||||
|
||||
def test_tensor_types(self):
|
||||
def fn(dtype, tensor_type):
|
||||
x = torch.empty(4, dtype=dtype)
|
||||
|
|
|
|||
|
|
@ -6359,9 +6359,8 @@ def _infer_scalar_type(obj):
|
|||
# xref: recursive_store in torch/csrc/utils/tensor_new.cpp
|
||||
def _recursive_build(scalarType: torch.dtype, obj: TensorOrNumberLikeType):
|
||||
if isinstance(obj, Tensor) and obj.ndim <= 1:
|
||||
obj = obj.item()
|
||||
# fall through into next case
|
||||
if isinstance(obj, Number):
|
||||
return obj.detach().to(dtype=scalarType, device="cpu", copy=True).view(())
|
||||
elif isinstance(obj, Number):
|
||||
return torch.scalar_tensor(obj, dtype=scalarType)
|
||||
|
||||
seq = obj
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user