[dynamo] inlining into __iter__ of user defined object (#119243)

Fixes #119198.

This PR make dynamo inline `__iter__` of a user defined object instead of creating a graph break. Also added a new test, which shows:
1. the loop is unrolled
2. the length of the loop is guarded when inlining `__iter__`
```python
class Mod:
    def __init__(self):
        self.a = [torch.randn(2, 2), torch.randn(2, 2)]

    def __iter__(self):
        return iter(self.a)

def f(mod):
    ret = []
    for x in mod:
        ret.append(x + 1)
    return ret
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119243
Approved by: https://github.com/jansel
This commit is contained in:
ydwu4 2024-02-07 10:04:58 -08:00 committed by PyTorch MergeBot
parent b181e52a8f
commit b251bca205
2 changed files with 47 additions and 1 deletions

View File

@ -852,6 +852,42 @@ class MiscTests(torch._dynamo.test_case.TestCase):
else:
self.assertExpectedInline(counts.op_count, """4""")
def test_user_defined_iter(self):
class Mod:
def __init__(self):
self.a = [torch.randn(2, 2), torch.randn(2, 2)]
def __iter__(self):
return iter(self.a)
def f(mod):
ret = []
for x in mod:
ret.append(x + 1)
return ret
mod = Mod()
counts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(counts, nopython=True)(f)
ref = f(mod)
res = opt_fn(mod)
res = opt_fn(mod)
res = opt_fn(mod)
res = opt_fn(mod)
self.assertTrue(same(ref, res))
self.assertEqual(counts.frame_count, 1)
mod.a.append(torch.randn(2, 2))
# `for x in mod` is inlined, where iter(m.a) creates a guard on the list length of m.a
# Mutating length of mod.a causes a re-compilation.
ref2 = f(mod)
res2 = opt_fn(mod)
res2 = opt_fn(mod)
res2 = opt_fn(mod)
res2 = opt_fn(mod)
self.assertTrue(same(ref2, res2))
self.assertEqual(counts.frame_count, 2)
def test_compare_shapes_eq(self):
def compare_shapes(a, b, to_list):
x = list(a.unsqueeze(-1).shape) if to_list else a.shape

View File

@ -930,7 +930,17 @@ class BuiltinVariable(VariableTracker):
mutable_local=MutableLocal(),
)
call_iter = _call_iter_tuple_list
def call_iter(self, tx, obj, *args, **kwargs):
# Handle the case where we are iterating over a tuple, list or iterator
ret = self._call_iter_tuple_list(tx, obj, *args, **kwargs)
if ret is None:
# If the object doesn't implement a __iter__ method, it will be an error in eager mode when calling iter on it anyway.
# If the object implements a __iter__ method, inlining effectively forwards the call to another iter call
# (e.g. when __iter__ just returns iter(self.list)) or return a user-defined iterator.
return obj.call_method(tx, "__iter__", args, kwargs)
return ret
call_tuple = _call_iter_tuple_list
call_list = _call_iter_tuple_list