From 0dad85b402c2068c76a61d7c1ddcb60767ca1ef7 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 13 Dec 2023 04:08:09 +0000 Subject: [PATCH] [Dynamo] Fix torch.tensor call with tuple (#115713) Land #114383 on behalf of @ezyang since he is on recharge and this is an high priority issue. Fix #114231 Pull Request resolved: https://github.com/pytorch/pytorch/pull/115713 Approved by: https://github.com/angelayi, https://github.com/voznesenskym --- test/dynamo/test_unspec.py | 4 ++++ torch/_dynamo/variables/torch.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_unspec.py b/test/dynamo/test_unspec.py index 9628720ea7f..7110155d169 100644 --- a/test/dynamo/test_unspec.py +++ b/test/dynamo/test_unspec.py @@ -354,6 +354,9 @@ class UnspecTests(torch._dynamo.test_case.TestCase): def f3(v): return torch.tensor(v.item()) + def f4(v): + return torch.tensor((v.item(),)) + optimize = torch.compile(backend="aot_eager", fullgraph=True) r = torch.randn(1) @@ -361,6 +364,7 @@ class UnspecTests(torch._dynamo.test_case.TestCase): self.assertEqual(f1(r), optimize(f1)(r)) self.assertEqual(f2(r), optimize(f2)(r)) self.assertEqual(f3(r), optimize(f3)(r)) + self.assertEqual(f4(r), optimize(f4)(r)) def test_sym_int_conversion(self): def f(x): diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index d4cbbee5082..d8ffeb76065 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -530,7 +530,7 @@ For now, dynamo will explicitly graph break when it encounters user code with th # NB: This includes UnspecializedPythonVariable if isinstance(x, (TensorVariable, SymNodeVariable)): return True - elif isinstance(x, ListVariable): + elif isinstance(x, (ListVariable, TupleVariable)): return any(check_any_unspec(y) for y in x.items) # TODO: there maybe other recursive structures you need to # check