[dynamo] Add __iter__ for iterable VariableTrackers (#166349)

This is in preparation for implementing iter with a polyfill
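For context, a minimal sketch of the user-visible effect (backend choice and values are illustrative, not part of this commit): an explicit __iter__ call on a builtin container can now be traced without a graph break.

    import torch

    @torch.compile(backend="eager", fullgraph=True)
    def fn(x):
        it = [1, 2, 3].__iter__()  # handled by BaseListVariable.call_method below
        return x + next(it)

    print(fn(torch.ones(3)))  # tensor([2., 2., 2.])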

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166349
Approved by: https://github.com/guilhermeleobas
Authored by Rob Timpe on 2025-10-28 20:52:19 +00:00; committed by PyTorch MergeBot
parent 59ddfb69a7
commit 3d4a2d8a93
11 changed files with 152 additions and 14 deletions

View File

@@ -1623,6 +1623,12 @@ class DictMethodsTests(torch._dynamo.test_case.TestCase):
         self.assertNotEqual(self.thetype, other)
         self.assertTrue(self.thetype is not other, f"{self.thetype=}, {other=}")
 
+    @make_dynamo_test
+    def test_dict___iter__(self):
+        d = self.thetype({1: 2})
+        it = d.__iter__()
+        self.assertEqual(next(it), 1)
+
 
 class DictSubclassMethodsTests(DictMethodsTests):
     thetype = SimpleDict

View File

@@ -168,6 +168,14 @@ class TupleTests(torch._dynamo.test_case.TestCase):
         self.assertRaises(TypeError, p.__contains__)
         self.assertRaises(TypeError, p.__contains__, 1, 2)
 
+    @make_dynamo_test
+    def test___iter__(self):
+        p = self.thetype([1])
+        it = p.__iter__()
+        self.assertEqual(next(it), 1)
+        it = p.__iter__().__iter__()
+        self.assertEqual(next(it), 1)
+
 
 class ListTests(TupleTests):
     # List methods

View File

@@ -1272,6 +1272,20 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
         r2 = opt_fn(d)
         self.assertEqual(r1, r2)
 
+    def test_tensor__iter__(self):
+        def fn(x):
+            it = x.__iter__()
+            for y in it:
+                y.add_(1.0)
+            return y
+
+        torch._dynamo.testing.standard_test(
+            self,
+            fn,
+            1,
+            expected_ops=20,
+        )
+
     def test_tensor_iter(self):
         def fn(x):
             for y in x:
@@ -1961,6 +1975,15 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
         self.assertTrue(same(res2, torch.ones(2)))
         self.assertTrue(same(res3, torch.ones(3)))
 
+    def test_range___iter__(self):
+        def func(x):
+            it = range(3).__iter__()
+            return x + next(it)
+
+        opt_func = torch.compile(func, backend="eager", fullgraph=True)
+        x = torch.randn(3)
+        self.assertTrue(same(func(x), opt_func(x)))
+
     def test_range_iter_side_effects(self):
         @torch.compile(backend="eager", fullgraph=True)
         def run(x, it):
@@ -9608,6 +9631,18 @@ def ___make_guard_fn():
         self.assertEqual(msg, "shape torch.Size([8, 8]) batch size 1")
         self.assertEqual(res, img1 + torch.sin(img1))
 
+    def test_str___iter__(self):
+        def fn(x):
+            s = "a"
+            if next(s.__iter__()) == "a":
+                return x + 1
+            else:
+                return x
+
+        x = torch.randn(3)
+        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
+        self.assertEqual(fn(x), opt_fn(x))
+
     def test_str_format_return2(self):
         @torch.compile(backend="eager", fullgraph=True)
         def fn(img):

View File

@@ -2763,6 +2763,22 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
         self.assertEqual(eager_res, optim_res)
         self.assertEqual(cnt.frame_count, 1)
 
+    def test_specialized_module___iter__(self):
+        ml = torch.nn.ModuleList(
+            [
+                torch.nn.Linear(10, 10),
+            ]
+        )
+        ml.torchdynamo_force_dynamic = False
+
+        def f(x):
+            it = ml.__iter__()
+            return next(it)(x)
+
+        opt_f = torch.compile(f, backend="eager", fullgraph=True)
+        x = torch.randn(10)
+        self.assertEqual(f(x), opt_f(x))
+
     def test_module_dict_iter_keys(self):
         class MyModule(torch.nn.Module):
             def __init__(self) -> None:

View File

@@ -23,7 +23,7 @@ from ..utils import (
     np,
     raise_args_mismatch,
 )
-from .base import VariableTracker
+from .base import ValueMutationNew, VariableTracker
 
 if TYPE_CHECKING:
@@ -168,6 +168,14 @@ its type to `common_constant_types`.
                 return ConstantVariable.create(self.value.join(arg_const))
             except NotImplementedError:
                 return super().call_method(tx, name, args, kwargs)
+        elif name == "__iter__" and istype(self.value, str):
+            # this could be some generic iterator to avoid the circular import,
+            # but ListIterator does what we want
+            from .lists import ListIteratorVariable
+
+            return ListIteratorVariable(
+                self.unpack_var_sequence(tx), mutation_type=ValueMutationNew()
+            )
 
         if any(isinstance(x, SymNodeVariable) for x in args):
             # Promote to SymNodeVariable for operations involving dynamic shapes.
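A minimal sketch of this new str path (it mirrors the test_str___iter__ test added earlier in this diff; the function and values are illustrative):

    import torch

    @torch.compile(backend="eager", fullgraph=True)
    def fn(x):
        it = "ab".__iter__()  # ConstantVariable yields a ListIteratorVariable over the characters
        return x + 1 if next(it) == "a" else x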

View File

@@ -34,7 +34,7 @@ from .. import graph_break_hints, polyfills, variables
 from ..bytecode_transformation import create_call_function, create_instruction
 from ..exc import raise_observed_exception, unimplemented_v2
 from ..guards import GuardBuilder, install_guard
-from ..source import is_from_local_source
+from ..source import is_constant_source, is_from_local_source
 from ..utils import (
     cmp_name_to_op_mapping,
     dict_items,
@@ -46,6 +46,7 @@ from ..utils import (
 )
 from .base import ValueMutationNew, VariableTracker
 from .constant import ConstantVariable
+from .lists import ListIteratorVariable
 
 if TYPE_CHECKING:
@@ -779,6 +780,12 @@ class ConstDictVariable(VariableTracker):
         elif name == "__ior__":
             self.call_method(tx, "update", args, kwargs)
             return self
+        elif name == "__iter__":
+            if self.source and not is_constant_source(self.source):
+                tx.output.guard_on_key_order.add(self.source)
+            return ListIteratorVariable(
+                self.unpack_var_sequence(tx), mutation_type=ValueMutationNew()
+            )
         else:
             return super().call_method(tx, name, args, kwargs)
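Iteration observes dict key order, so for a dict that does not come from a constant source the branch above records a key-order guard. A minimal sketch of the behavior this protects (names are illustrative):

    import torch

    d = {"a": 1, "b": 2}

    @torch.compile(backend="eager", fullgraph=True)
    def fn(x):
        it = d.__iter__()       # ConstDictVariable.call_method("__iter__")
        return x + d[next(it)]  # first key in insertion order

    fn(torch.ones(2))
    # d is loaded from a global (non-constant) source, so its key order is
    # guarded; the same keys inserted in a different order should trigger a
    # recompile rather than silently reuse the first key.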
@@ -787,12 +794,16 @@
         return [x.vt for x in self.items.keys()]
 
     def call_obj_hasattr(self, tx, name):
-        # dict not allow setting arbitrary attributes. To check for hasattr, we can just check the __dict__ of the dict.
-        # OrderedDict though requires side effects tracking because it supports arbitrary setattr.
-        if self.user_cls is dict:
-            if name in self.user_cls.__dict__:
+        # dict does not allow setting arbitrary attributes. OrderedDict and
+        # defaultdict allow arbitrary setattr, but not deletion of default attrs.
+        if any(
+            self.user_cls is t
+            for t in (dict, collections.OrderedDict, collections.defaultdict)
+        ):
+            if hasattr(self.user_cls, name):
                 return ConstantVariable.create(True)
-            return ConstantVariable.create(False)
+            if self.user_cls is dict:
+                return ConstantVariable.create(False)
 
         msg = f"hasattr on {self.user_cls} is not supported"
         unimplemented_v2(
@@ -879,6 +890,13 @@ class MappingProxyVariable(VariableTracker):
             )
             return self.dv_dict.call_method(tx, name, args, kwargs)
 
+    def call_obj_hasattr(
+        self, tx: "InstructionTranslator", name: str
+    ) -> "VariableTracker":
+        if self.python_type() is types.MappingProxyType:
+            return ConstantVariable.create(name in types.MappingProxyType.__dict__)
+        return super().call_obj_hasattr(tx, name)
+
 
 class NNModuleHooksDictVariable(ConstDictVariable):
     # Special class to avoid adding any guards on the nn module hook ids.
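A sketch of the new MappingProxyVariable.call_obj_hasattr (assuming Dynamo wraps the proxy object as a MappingProxyVariable; names are illustrative):

    import types

    import torch

    mp = types.MappingProxyType({"a": 1})

    @torch.compile(backend="eager", fullgraph=True)
    def fn(x):
        # "keys" is in MappingProxyType.__dict__; read-only proxies have no "update"
        if hasattr(mp, "keys") and not hasattr(mp, "update"):
            return x + 1
        return x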
@@ -1388,6 +1406,10 @@ class DictViewVariable(VariableTracker):
     ) -> "VariableTracker":
         if name == "__len__":
             return self.dv_dict.call_method(tx, name, args, kwargs)
+        elif name == "__iter__":
+            return ListIteratorVariable(
+                self.view_items_vt, mutation_type=ValueMutationNew()
+            )
         return super().call_method(tx, name, args, kwargs)
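And a sketch of the dict-view path (illustrative):

    import torch

    @torch.compile(backend="eager", fullgraph=True)
    def fn(x):
        d = {"a": 1, "b": 2}
        it = d.values().__iter__()  # DictViewVariable.call_method("__iter__")
        return x + next(it)         # 1, the first value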

View File

@@ -252,6 +252,26 @@ class IteratorVariable(VariableTracker):
     def has_force_unpack_var_sequence(self, tx) -> bool:
         return True
 
+    def call_obj_hasattr(
+        self, tx: "InstructionTranslator", name: str
+    ) -> "VariableTracker":
+        if name == "__iter__" or name == "__next__":
+            return variables.ConstantVariable.create(True)
+        return super().call_obj_hasattr(tx, name)
+
+    def call_method(
+        self,
+        tx: "InstructionTranslator",
+        name,
+        args: list[VariableTracker],
+        kwargs: dict[str, VariableTracker],
+    ) -> VariableTracker:
+        if name == "__iter__":
+            return self
+        elif name == "__next__":
+            return self.next_variable(tx)
+        return super().call_method(tx, name, args, kwargs)
+
 
 class ObjectIteratorVariable(IteratorVariable):
     """

View File

@@ -294,6 +294,10 @@ class BaseListVariable(VariableTracker):
                 [variables.BuiltinVariable(cmp_name_to_op_mapping[name]), left, right],
                 {},
             )
+        elif name == "__iter__":
+            return ListIteratorVariable(
+                list(self.items), mutation_type=ValueMutationNew()
+            )
         return super().call_method(tx, name, args, kwargs)
@@ -472,9 +476,9 @@
     def call_obj_hasattr(
         self, tx: "InstructionTranslator", name: str
    ) -> "VariableTracker":
-        if self.python_type() is not range:
-            return super().call_obj_hasattr(tx, name)
-        return variables.ConstantVariable.create(hasattr(range(0), name))
+        if self.python_type() is range:
+            return variables.ConstantVariable.create(name in range.__dict__)
+        return super().call_obj_hasattr(tx, name)
 
     def range_equals(self, other: "RangeVariable"):
         r0, r1 = self, other
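Note the behavioral nuance: hasattr on a range now answers from range.__dict__ (names defined on range itself) instead of probing a throwaway instance, so attributes that exist only through inheritance from object no longer report True. The same __dict__ pattern is used for deque below and for MappingProxyType in the dict changes above. A sketch (illustrative):

    import torch

    @torch.compile(backend="eager", fullgraph=True)
    def fn(x):
        r = range(3)
        # "start" and "__iter__" are in range.__dict__; "append" is not
        if hasattr(r, "start") and not hasattr(r, "append"):
            return x + r.start
        return x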
@@ -1064,6 +1068,13 @@ class DequeVariable(CommonListMethodsVariable):
             self.items[:] = self.items[slice_within_maxlen]
         return result
 
+    def call_obj_hasattr(
+        self, tx: "InstructionTranslator", name: str
+    ) -> "VariableTracker":
+        if self.python_type() is collections.deque:
+            return variables.ConstantVariable.create(name in collections.deque.__dict__)
+        return super().call_obj_hasattr(tx, name)
+
 
 class TupleVariable(BaseListVariable):
     def python_type(self):
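The deque counterpart, sketched under the same assumptions:

    import collections

    import torch

    @torch.compile(backend="eager", fullgraph=True)
    def fn(x):
        d = collections.deque([1, 2])
        # "appendleft" is in collections.deque.__dict__; "update" is not
        if hasattr(d, "appendleft") and not hasattr(d, "update"):
            return x + d[0]
        return x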

View File

@@ -796,6 +796,10 @@ class NNModuleVariable(VariableTracker):
                 f"{len(args)} args and {len(kwargs)} kwargs",
             )
             return ConstantVariable.create(len(module))
+        elif name == "__iter__":
+            return ListIteratorVariable(
+                self.unpack_var_sequence(tx), mutation_type=ValueMutationNew()
+            )
         elif (
             name == "__contains__"
             and isinstance(module, (torch.nn.ModuleDict, torch.nn.ParameterDict))

View File

@@ -66,9 +66,9 @@ from ..utils import (
     set_example_value,
     tensortype_to_dtype,
 )
-from .base import AttributeMutationNew, VariableTracker
+from .base import AttributeMutationNew, ValueMutationNew, VariableTracker
 from .constant import ConstantVariable
-from .lists import SizeVariable
+from .lists import ListIteratorVariable, SizeVariable
 from .user_defined import UserDefinedClassVariable
@@ -427,7 +427,7 @@ class TensorVariable(VariableTracker):
         # Today, var_getattr returns GetAttrVariable for both non-existent
         # attributes and existing attributes. This is a bug and requires more
         # deep dive.
-        if name in ("size", "stride"):
+        if name in ("size", "stride", "__iter__"):
             return ConstantVariable(True)
 
         try:
@@ -1079,6 +1079,14 @@
         tx = InstructionTranslator.current_tx()
         return self.call_method(tx, "size", [ConstantVariable.create(0)], {})
 
+    def method___iter__(self):
+        from ..symbolic_convert import InstructionTranslator
+
+        tx = InstructionTranslator.current_tx()
+        return ListIteratorVariable(
+            self.unpack_var_sequence(tx), mutation_type=ValueMutationNew()
+        )
+
     def method_addcmul_(self, tensor1, tensor2, *, value=None):
         from ..symbolic_convert import InstructionTranslator
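method___iter__ makes an explicit x.__iter__() trace the same way as "for y in x": the tensor is unpacked along dim 0 into a ListIteratorVariable. A sketch (illustrative):

    import torch

    @torch.compile(backend="eager", fullgraph=True)
    def row_sum(x):
        it = x.__iter__()           # TensorVariable.method___iter__
        return next(it) + next(it)  # sum of the first two rows

    row_sum(torch.randn(4, 3))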
@@ -1612,7 +1620,7 @@ class NumpyNdarrayVariable(TensorVariable):
             ),
             hints=[*graph_break_hints.FUNDAMENTAL],
         )
-        if name in ["__len__", "size", "tolist"]:
+        if name in ["__len__", "size", "tolist", "__iter__"]:
             # delegate back to TensorVariable
             return super().call_method(tx, name, args, kwargs)
         if name in ("tostring", "tobytes", "__delattr__"):