mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[dynamo] inlining into __iter__ of user defined object (#119243)
Fixes #119198. This PR make dynamo inline `__iter__` of a user defined object instead of creating a graph break. Also added a new test, which shows: 1. the loop is unrolled 2. the length of the loop is guarded when inlining `__iter__` ```python class Mod: def __init__(self): self.a = [torch.randn(2, 2), torch.randn(2, 2)] def __iter__(self): return iter(self.a) def f(mod): ret = [] for x in mod: ret.append(x + 1) return ret ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/119243 Approved by: https://github.com/jansel
This commit is contained in:
parent
b181e52a8f
commit
b251bca205
|
|
@ -852,6 +852,42 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
|||
else:
|
||||
self.assertExpectedInline(counts.op_count, """4""")
|
||||
|
||||
def test_user_defined_iter(self):
|
||||
class Mod:
|
||||
def __init__(self):
|
||||
self.a = [torch.randn(2, 2), torch.randn(2, 2)]
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.a)
|
||||
|
||||
def f(mod):
|
||||
ret = []
|
||||
for x in mod:
|
||||
ret.append(x + 1)
|
||||
return ret
|
||||
|
||||
mod = Mod()
|
||||
counts = torch._dynamo.testing.CompileCounter()
|
||||
opt_fn = torch._dynamo.optimize(counts, nopython=True)(f)
|
||||
ref = f(mod)
|
||||
res = opt_fn(mod)
|
||||
res = opt_fn(mod)
|
||||
res = opt_fn(mod)
|
||||
res = opt_fn(mod)
|
||||
self.assertTrue(same(ref, res))
|
||||
self.assertEqual(counts.frame_count, 1)
|
||||
|
||||
mod.a.append(torch.randn(2, 2))
|
||||
# `for x in mod` is inlined, where iter(m.a) creates a guard on the list length of m.a
|
||||
# Mutating length of mod.a causes a re-compilation.
|
||||
ref2 = f(mod)
|
||||
res2 = opt_fn(mod)
|
||||
res2 = opt_fn(mod)
|
||||
res2 = opt_fn(mod)
|
||||
res2 = opt_fn(mod)
|
||||
self.assertTrue(same(ref2, res2))
|
||||
self.assertEqual(counts.frame_count, 2)
|
||||
|
||||
def test_compare_shapes_eq(self):
|
||||
def compare_shapes(a, b, to_list):
|
||||
x = list(a.unsqueeze(-1).shape) if to_list else a.shape
|
||||
|
|
|
|||
|
|
@ -930,7 +930,17 @@ class BuiltinVariable(VariableTracker):
|
|||
mutable_local=MutableLocal(),
|
||||
)
|
||||
|
||||
call_iter = _call_iter_tuple_list
|
||||
def call_iter(self, tx, obj, *args, **kwargs):
|
||||
# Handle the case where we are iterating over a tuple, list or iterator
|
||||
ret = self._call_iter_tuple_list(tx, obj, *args, **kwargs)
|
||||
|
||||
if ret is None:
|
||||
# If the object doesn't implement a __iter__ method, it will be an error in eager mode when calling iter on it anyway.
|
||||
# If the object implements a __iter__ method, inlining effectively forwards the call to another iter call
|
||||
# (e.g. when __iter__ just returns iter(self.list)) or return a user-defined iterator.
|
||||
return obj.call_method(tx, "__iter__", args, kwargs)
|
||||
return ret
|
||||
|
||||
call_tuple = _call_iter_tuple_list
|
||||
call_list = _call_iter_tuple_list
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user