[Dynamo] Fix torch.tensor call with tuple (#115713)

Land #114383 on behalf of @ezyang since he is on recharge and this is a high-priority issue.
Fix #114231
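
For context, a minimal reproduction sketch of the issue, modeled on the new test case added below; the aot_eager backend and fullgraph=True are taken from that test:

```python
import torch

def f4(v):
    # Constructing a tensor from a tuple that wraps a .item() result;
    # this pattern previously failed under torch.compile with fullgraph=True.
    return torch.tensor((v.item(),))

opt_f4 = torch.compile(f4, backend="aot_eager", fullgraph=True)
r = torch.randn(1)
torch.testing.assert_close(f4(r), opt_f4(r))
```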

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115713
Approved by: https://github.com/angelayi, https://github.com/voznesenskym
Yanbo Liang authored on 2023-12-13 04:08:09 +00:00, committed by PyTorch MergeBot
parent 38101e349e
commit 0dad85b402
2 changed files with 5 additions and 1 deletion


@@ -354,6 +354,9 @@ class UnspecTests(torch._dynamo.test_case.TestCase):
         def f3(v):
             return torch.tensor(v.item())
 
+        def f4(v):
+            return torch.tensor((v.item(),))
+
         optimize = torch.compile(backend="aot_eager", fullgraph=True)
 
         r = torch.randn(1)
@@ -361,6 +364,7 @@ class UnspecTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(f1(r), optimize(f1)(r))
         self.assertEqual(f2(r), optimize(f2)(r))
         self.assertEqual(f3(r), optimize(f3)(r))
+        self.assertEqual(f4(r), optimize(f4)(r))
 
     def test_sym_int_conversion(self):
         def f(x):


@@ -530,7 +530,7 @@ For now, dynamo will explicitly graph break when it encounters user code with this behavior.
             # NB: This includes UnspecializedPythonVariable
             if isinstance(x, (TensorVariable, SymNodeVariable)):
                 return True
-            elif isinstance(x, ListVariable):
+            elif isinstance(x, (ListVariable, TupleVariable)):
                 return any(check_any_unspec(y) for y in x.items)
            # TODO: there maybe other recursive structures you need to
            # check
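
With TupleVariable added to the isinstance check, check_any_unspec recurses into tuple elements the same way it already did for list elements, so a tuple argument such as (v.item(),) is recognized as containing unspecialized values and torch.tensor handles it like the equivalent list form.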