diff --git a/test/dynamo/test_unspec.py b/test/dynamo/test_unspec.py index 9628720ea7f..7110155d169 100644 --- a/test/dynamo/test_unspec.py +++ b/test/dynamo/test_unspec.py @@ -354,6 +354,9 @@ class UnspecTests(torch._dynamo.test_case.TestCase): def f3(v): return torch.tensor(v.item()) + def f4(v): + return torch.tensor((v.item(),)) + optimize = torch.compile(backend="aot_eager", fullgraph=True) r = torch.randn(1) @@ -361,6 +364,7 @@ class UnspecTests(torch._dynamo.test_case.TestCase): self.assertEqual(f1(r), optimize(f1)(r)) self.assertEqual(f2(r), optimize(f2)(r)) self.assertEqual(f3(r), optimize(f3)(r)) + self.assertEqual(f4(r), optimize(f4)(r)) def test_sym_int_conversion(self): def f(x): diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index d4cbbee5082..d8ffeb76065 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -530,7 +530,7 @@ For now, dynamo will explicitly graph break when it encounters user code with th # NB: This includes UnspecializedPythonVariable if isinstance(x, (TensorVariable, SymNodeVariable)): return True - elif isinstance(x, ListVariable): + elif isinstance(x, (ListVariable, TupleVariable)): return any(check_any_unspec(y) for y in x.items) # TODO: there maybe other recursive structures you need to # check