mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[Dynamo] Fix torch.tensor call with tuple (#115713)
Land #114383 on behalf of @ezyang since he is on recharge and this is an high priority issue. Fix #114231 Pull Request resolved: https://github.com/pytorch/pytorch/pull/115713 Approved by: https://github.com/angelayi, https://github.com/voznesenskym
This commit is contained in:
parent
38101e349e
commit
0dad85b402
|
|
@ -354,6 +354,9 @@ class UnspecTests(torch._dynamo.test_case.TestCase):
|
|||
def f3(v):
|
||||
return torch.tensor(v.item())
|
||||
|
||||
def f4(v):
|
||||
return torch.tensor((v.item(),))
|
||||
|
||||
optimize = torch.compile(backend="aot_eager", fullgraph=True)
|
||||
|
||||
r = torch.randn(1)
|
||||
|
|
@ -361,6 +364,7 @@ class UnspecTests(torch._dynamo.test_case.TestCase):
|
|||
self.assertEqual(f1(r), optimize(f1)(r))
|
||||
self.assertEqual(f2(r), optimize(f2)(r))
|
||||
self.assertEqual(f3(r), optimize(f3)(r))
|
||||
self.assertEqual(f4(r), optimize(f4)(r))
|
||||
|
||||
def test_sym_int_conversion(self):
|
||||
def f(x):
|
||||
|
|
|
|||
|
|
@ -530,7 +530,7 @@ For now, dynamo will explicitly graph break when it encounters user code with th
|
|||
# NB: This includes UnspecializedPythonVariable
|
||||
if isinstance(x, (TensorVariable, SymNodeVariable)):
|
||||
return True
|
||||
elif isinstance(x, ListVariable):
|
||||
elif isinstance(x, (ListVariable, TupleVariable)):
|
||||
return any(check_any_unspec(y) for y in x.items)
|
||||
# TODO: there maybe other recursive structures you need to
|
||||
# check
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user