Dynamo: Fix graph break when iterating over tensor (#94326)

Supports the following with dynamic shapes:
```python
for element in tensor:
    # do stuff with element
```

The approach follows what's done when `call_range()` is invoked with dynamic shape inputs: guard on the tensor's size and continue tracing with the real size value obtained from `dyn_dim0_size.evaluate_expr()`.
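
For illustration, a minimal sketch (not part of the PR) of the kind of code that now traces without a graph break; it reuses the `torch._dynamo.optimize("eager", nopython=True)` pattern from the tests below:
```python
import torch
import torch._dynamo


def fn(x):
    # Iterating a tensor yields its dim-0 slices; with this change Dynamo
    # guards on len(x) instead of breaking the graph when the size is symbolic.
    for y in x:
        y.add_(1.0)
    return y


opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
out = opt_fn(torch.randn(4, 3))
```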

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94326
Approved by: https://github.com/ezyang
Joel Schlosser 2023-02-08 11:33:19 -05:00 committed by PyTorch MergeBot
parent 7bfc59993d
commit b5ef37b9a4
3 changed files with 52 additions and 17 deletions


@@ -60,20 +60,6 @@ unittest.expectedFailure(
    # Cannot call sizes() on tensor with symbolic sizes/strides
)
# DynamicShapesExportTests
unittest.expectedFailure(
    DynamicShapesExportTests.test_export_with_constant_list_nonzero_dynamic_shapes
)
unittest.expectedFailure(
    DynamicShapesExportTests.test_export_with_constant_list_nonzero_free_function_dynamic_shapes
)
unittest.expectedFailure(
    DynamicShapesExportTests.test_export_with_constant_tuple_nonzero_dynamic_shapes
)
unittest.expectedFailure(
    DynamicShapesExportTests.test_export_with_constant_tuple_nonzero_dynamic_shapes
)
# DynamicShapesSubGraphTests
unittest.expectedFailure(


@@ -354,6 +354,17 @@ class MiscTests(torch._dynamo.test_case.TestCase):
        r2 = opt_fn(i)
        self.assertTrue(same(r1, r2))

    def test_tensor_iter(self):
        def fn(x):
            for y in x:
                y.add_(1.0)
            return y

        # expect extra size node for dynamic
        torch._dynamo.testing.standard_test(
            self, fn, 1, expected_ops=20, expected_ops_dynamic=21
        )

    def test_empty_list(self):
        def fn(x, ll):
            if len(ll) == 0 and not ll and ll is not None:
@@ -3659,6 +3670,36 @@ class MiscTests(torch._dynamo.test_case.TestCase):
            "tensor 'x' size mismatch at index 0. expected 2, actual 3",
        )

    def test_guard_failure_fn_tensor_iter(self):
        def fn(x):
            for y in x:
                y.add_(1.0)
            return y

        guard_failure = None

        def guard_failures(failure):
            nonlocal guard_failure
            guard_failure = failure

        opt_fn = torch._dynamo.optimize(
            "eager", nopython=True, guard_fail_fn=guard_failures
        )(fn)

        args1 = torch.randn(10, 10)
        out = fn(args1)
        opt_out = opt_fn(args1)
        self.assertTrue(same(out, opt_out))

        args2 = torch.randn(9, 10)
        out = fn(args2)
        opt_out = opt_fn(args2)
        self.assertTrue(same(out, opt_out))

        # guard is expected for both static and dynamic shapes
        self.assertTrue(guard_failure is not None)
        self.assertEqual(guard_failure[0], "len(x) == 10")

    def test_restore_graphstate(self):
        # This function does some guard accumulation,
        # and then rolls back due to control flow.


@@ -217,15 +217,23 @@ class TensorVariable(VariableTracker):
        return result

    def has_unpack_var_sequence(self, tx):
        return (self.size is not None and len(self.size) > 0) or (
            self.size is None and config.dynamic_shapes
        )

    def unpack_var_sequence(self, tx, idxes=None):
        from .builder import wrap_fx_proxy

        options = VariableTracker.propagate(self)
        if idxes is None:
            if self.size:
                idxes = range(self.size[0])
                length = self.size[0]
            else:
                return super(TensorVariable, self).unpack_var_sequence(tx)
        options = VariableTracker.propagate(self)
                dyn_length = self.call_method(tx, "size", [ConstantVariable(0)], {})
                assert isinstance(dyn_length, SymNodeVariable)
                length = dyn_length.evaluate_expr(tx.output)
            idxes = range(length)
        return [wrap_fx_proxy(tx, self.as_proxy()[i], **options) for i in idxes]

    def call_method(
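
As a rough illustration of the control flow added above (a simplified sketch, not Dynamo's actual internals): with a known static size the loop length comes straight from `self.size[0]`, while with dynamic shapes the symbolic `size(0)` is evaluated to a concrete value, which installs a guard such as `len(x) == 10`, and the iteration is then unrolled over that many indices.
```python
import torch

# Simplified, hypothetical sketch of the unrolling decision
# (not the real TensorVariable code): pick a concrete length, then unroll.
def unpack_tensor(x, static_size=None):
    if static_size is not None:
        length = static_size[0]  # static case: size already known while tracing
    else:
        length = x.size(0)  # dynamic case: evaluating size(0) is what gets guarded on
    return [x[i] for i in range(length)]

rows = unpack_tensor(torch.randn(5, 3))  # five dim-0 slices
```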