mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Dynamo: Fix graph break when iterating over tensor (#94326)
Supports the following with dynamic shapes:
```python
for element in tensor:
# do stuff with element
```
Approach follows what's done when `call_range()` is invoked with dynamic shape inputs: guard on tensor size and continue tracing with a real size value from `dyn_dim0_size.evaluate_expr()`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94326
Approved by: https://github.com/ezyang
This commit is contained in:
parent
7bfc59993d
commit
b5ef37b9a4
|
|
@ -60,20 +60,6 @@ unittest.expectedFailure(
|
|||
# Cannot call sizes() on tensor with symbolic sizes/strides
|
||||
)
|
||||
|
||||
# DynamicShapesExportTests
|
||||
unittest.expectedFailure(
|
||||
DynamicShapesExportTests.test_export_with_constant_list_nonzero_dynamic_shapes
|
||||
)
|
||||
unittest.expectedFailure(
|
||||
DynamicShapesExportTests.test_export_with_constant_list_nonzero_free_function_dynamic_shapes
|
||||
)
|
||||
unittest.expectedFailure(
|
||||
DynamicShapesExportTests.test_export_with_constant_tuple_nonzero_dynamic_shapes
|
||||
)
|
||||
unittest.expectedFailure(
|
||||
DynamicShapesExportTests.test_export_with_constant_tuple_nonzero_dynamic_shapes
|
||||
)
|
||||
|
||||
|
||||
# DynamicShapesSubGraphTests
|
||||
unittest.expectedFailure(
|
||||
|
|
|
|||
|
|
@ -354,6 +354,17 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
|||
r2 = opt_fn(i)
|
||||
self.assertTrue(same(r1, r2))
|
||||
|
||||
def test_tensor_iter(self):
|
||||
def fn(x):
|
||||
for y in x:
|
||||
y.add_(1.0)
|
||||
return y
|
||||
|
||||
# expect extra size node for dynamic
|
||||
torch._dynamo.testing.standard_test(
|
||||
self, fn, 1, expected_ops=20, expected_ops_dynamic=21
|
||||
)
|
||||
|
||||
def test_empty_list(self):
|
||||
def fn(x, ll):
|
||||
if len(ll) == 0 and not ll and ll is not None:
|
||||
|
|
@ -3659,6 +3670,36 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
|||
"tensor 'x' size mismatch at index 0. expected 2, actual 3",
|
||||
)
|
||||
|
||||
def test_guard_failure_fn_tensor_iter(self):
|
||||
def fn(x):
|
||||
for y in x:
|
||||
y.add_(1.0)
|
||||
return y
|
||||
|
||||
guard_failure = None
|
||||
|
||||
def guard_failures(failure):
|
||||
nonlocal guard_failure
|
||||
guard_failure = failure
|
||||
|
||||
opt_fn = torch._dynamo.optimize(
|
||||
"eager", nopython=True, guard_fail_fn=guard_failures
|
||||
)(fn)
|
||||
|
||||
args1 = torch.randn(10, 10)
|
||||
out = fn(args1)
|
||||
opt_out = opt_fn(args1)
|
||||
self.assertTrue(same(out, opt_out))
|
||||
|
||||
args2 = torch.randn(9, 10)
|
||||
out = fn(args2)
|
||||
opt_out = opt_fn(args2)
|
||||
self.assertTrue(same(out, opt_out))
|
||||
|
||||
# guard is expected for both static and dynamic shapes
|
||||
self.assertTrue(guard_failure is not None)
|
||||
self.assertEqual(guard_failure[0], "len(x) == 10")
|
||||
|
||||
def test_restore_graphstate(self):
|
||||
# This function does some guard accumulation,
|
||||
# and then rolls back due to control flow.
|
||||
|
|
|
|||
|
|
@ -217,15 +217,23 @@ class TensorVariable(VariableTracker):
|
|||
|
||||
return result
|
||||
|
||||
def has_unpack_var_sequence(self, tx):
|
||||
return (self.size is not None and len(self.size) > 0) or (
|
||||
self.size is None and config.dynamic_shapes
|
||||
)
|
||||
|
||||
def unpack_var_sequence(self, tx, idxes=None):
|
||||
from .builder import wrap_fx_proxy
|
||||
|
||||
options = VariableTracker.propagate(self)
|
||||
if idxes is None:
|
||||
if self.size:
|
||||
idxes = range(self.size[0])
|
||||
length = self.size[0]
|
||||
else:
|
||||
return super(TensorVariable, self).unpack_var_sequence(tx)
|
||||
options = VariableTracker.propagate(self)
|
||||
dyn_length = self.call_method(tx, "size", [ConstantVariable(0)], {})
|
||||
assert isinstance(dyn_length, SymNodeVariable)
|
||||
length = dyn_length.evaluate_expr(tx.output)
|
||||
idxes = range(length)
|
||||
return [wrap_fx_proxy(tx, self.as_proxy()[i], **options) for i in idxes]
|
||||
|
||||
def call_method(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user