mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo] Add __iter__ for iterable VariableTrackers (#166349)
This is in preparation for implementing iter with a polyfill Pull Request resolved: https://github.com/pytorch/pytorch/pull/166349 Approved by: https://github.com/guilhermeleobas
This commit is contained in:
parent
59ddfb69a7
commit
3d4a2d8a93
|
|
@ -1623,6 +1623,12 @@ class DictMethodsTests(torch._dynamo.test_case.TestCase):
|
|||
self.assertNotEqual(self.thetype, other)
|
||||
self.assertTrue(self.thetype is not other, f"{self.thetype=}, {other=}")
|
||||
|
||||
@make_dynamo_test
|
||||
def test_dict___iter__(self):
|
||||
d = self.thetype({1: 2})
|
||||
it = d.__iter__()
|
||||
self.assertEqual(next(it), 1)
|
||||
|
||||
|
||||
class DictSubclassMethodsTests(DictMethodsTests):
|
||||
thetype = SimpleDict
|
||||
|
|
|
|||
|
|
@ -168,6 +168,14 @@ class TupleTests(torch._dynamo.test_case.TestCase):
|
|||
self.assertRaises(TypeError, p.__contains__)
|
||||
self.assertRaises(TypeError, p.__contains__, 1, 2)
|
||||
|
||||
@make_dynamo_test
|
||||
def test___iter__(self):
|
||||
p = self.thetype([1])
|
||||
it = p.__iter__()
|
||||
self.assertEqual(next(it), 1)
|
||||
it = p.__iter__().__iter__()
|
||||
self.assertEqual(next(it), 1)
|
||||
|
||||
|
||||
class ListTests(TupleTests):
|
||||
# List methods
|
||||
|
|
|
|||
|
|
@ -1272,6 +1272,20 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
|
|||
r2 = opt_fn(d)
|
||||
self.assertEqual(r1, r2)
|
||||
|
||||
def test_tensor__iter__(self):
|
||||
def fn(x):
|
||||
it = x.__iter__()
|
||||
for y in it:
|
||||
y.add_(1.0)
|
||||
return y
|
||||
|
||||
torch._dynamo.testing.standard_test(
|
||||
self,
|
||||
fn,
|
||||
1,
|
||||
expected_ops=20,
|
||||
)
|
||||
|
||||
def test_tensor_iter(self):
|
||||
def fn(x):
|
||||
for y in x:
|
||||
|
|
@ -1961,6 +1975,15 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
|
|||
self.assertTrue(same(res2, torch.ones(2)))
|
||||
self.assertTrue(same(res3, torch.ones(3)))
|
||||
|
||||
def test_range___iter__(self):
|
||||
def func(x):
|
||||
it = range(3).__iter__()
|
||||
return x + next(it)
|
||||
|
||||
opt_func = torch.compile(func, backend="eager", fullgraph=True)
|
||||
x = torch.randn(3)
|
||||
self.assertTrue(same(func(x), opt_func(x)))
|
||||
|
||||
def test_range_iter_side_effects(self):
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def run(x, it):
|
||||
|
|
@ -9608,6 +9631,18 @@ def ___make_guard_fn():
|
|||
self.assertEqual(msg, "shape torch.Size([8, 8]) batch size 1")
|
||||
self.assertEqual(res, img1 + torch.sin(img1))
|
||||
|
||||
def test_str___iter__(self):
|
||||
def fn(x):
|
||||
s = "a"
|
||||
if next(s.__iter__()) == "a":
|
||||
return x + 1
|
||||
else:
|
||||
return x
|
||||
|
||||
x = torch.randn(3)
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
self.assertEqual(fn(x), opt_fn(x))
|
||||
|
||||
def test_str_format_return2(self):
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(img):
|
||||
|
|
|
|||
|
|
@ -2763,6 +2763,22 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
|
|||
self.assertEqual(eager_res, optim_res)
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
|
||||
def test_specialized_module___iter__(self):
|
||||
ml = torch.nn.ModuleList(
|
||||
[
|
||||
torch.nn.Linear(10, 10),
|
||||
]
|
||||
)
|
||||
ml.torchdynamo_force_dynamic = False
|
||||
|
||||
def f(x):
|
||||
it = ml.__iter__()
|
||||
return next(it)(x)
|
||||
|
||||
opt_f = torch.compile(f, backend="eager", fullgraph=True)
|
||||
x = torch.randn(10)
|
||||
self.assertEqual(f(x), opt_f(x))
|
||||
|
||||
def test_module_dict_iter_keys(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ from ..utils import (
|
|||
np,
|
||||
raise_args_mismatch,
|
||||
)
|
||||
from .base import VariableTracker
|
||||
from .base import ValueMutationNew, VariableTracker
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -168,6 +168,14 @@ its type to `common_constant_types`.
|
|||
return ConstantVariable.create(self.value.join(arg_const))
|
||||
except NotImplementedError:
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
elif name == "__iter__" and istype(self.value, str):
|
||||
# this could be some generic iterator to avoid the circular import,
|
||||
# but ListIterator does what we want
|
||||
from .lists import ListIteratorVariable
|
||||
|
||||
return ListIteratorVariable(
|
||||
self.unpack_var_sequence(tx), mutation_type=ValueMutationNew()
|
||||
)
|
||||
|
||||
if any(isinstance(x, SymNodeVariable) for x in args):
|
||||
# Promote to SymNodeVariable for operations involving dynamic shapes.
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ from .. import graph_break_hints, polyfills, variables
|
|||
from ..bytecode_transformation import create_call_function, create_instruction
|
||||
from ..exc import raise_observed_exception, unimplemented_v2
|
||||
from ..guards import GuardBuilder, install_guard
|
||||
from ..source import is_from_local_source
|
||||
from ..source import is_constant_source, is_from_local_source
|
||||
from ..utils import (
|
||||
cmp_name_to_op_mapping,
|
||||
dict_items,
|
||||
|
|
@ -46,6 +46,7 @@ from ..utils import (
|
|||
)
|
||||
from .base import ValueMutationNew, VariableTracker
|
||||
from .constant import ConstantVariable
|
||||
from .lists import ListIteratorVariable
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -779,6 +780,12 @@ class ConstDictVariable(VariableTracker):
|
|||
elif name == "__ior__":
|
||||
self.call_method(tx, "update", args, kwargs)
|
||||
return self
|
||||
elif name == "__iter__":
|
||||
if self.source and not is_constant_source(self.source):
|
||||
tx.output.guard_on_key_order.add(self.source)
|
||||
return ListIteratorVariable(
|
||||
self.unpack_var_sequence(tx), mutation_type=ValueMutationNew()
|
||||
)
|
||||
else:
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
|
|
@ -787,12 +794,16 @@ class ConstDictVariable(VariableTracker):
|
|||
return [x.vt for x in self.items.keys()]
|
||||
|
||||
def call_obj_hasattr(self, tx, name):
|
||||
# dict not allow setting arbitrary attributes. To check for hasattr, we can just check the __dict__ of the dict.
|
||||
# OrderedDict though requires side effects tracking because it supports arbitrary setattr.
|
||||
if self.user_cls is dict:
|
||||
if name in self.user_cls.__dict__:
|
||||
# dict not allow setting arbitrary attributes. OrderedDict and
|
||||
# defaultdict allow arbitrary setattr, but not deletion of default attrs
|
||||
if any(
|
||||
self.user_cls is t
|
||||
for t in (dict, collections.OrderedDict, collections.defaultdict)
|
||||
):
|
||||
if hasattr(self.user_cls, name):
|
||||
return ConstantVariable.create(True)
|
||||
return ConstantVariable.create(False)
|
||||
if self.user_cls is dict:
|
||||
return ConstantVariable.create(False)
|
||||
|
||||
msg = f"hasattr on {self.user_cls} is not supported"
|
||||
unimplemented_v2(
|
||||
|
|
@ -879,6 +890,13 @@ class MappingProxyVariable(VariableTracker):
|
|||
)
|
||||
return self.dv_dict.call_method(tx, name, args, kwargs)
|
||||
|
||||
def call_obj_hasattr(
|
||||
self, tx: "InstructionTranslator", name: str
|
||||
) -> "VariableTracker":
|
||||
if self.python_type() is types.MappingProxyType:
|
||||
return ConstantVariable.create(name in types.MappingProxyType.__dict__)
|
||||
return super().call_obj_hasattr(tx, name)
|
||||
|
||||
|
||||
class NNModuleHooksDictVariable(ConstDictVariable):
|
||||
# Special class to avoid adding any guards on the nn module hook ids.
|
||||
|
|
@ -1388,6 +1406,10 @@ class DictViewVariable(VariableTracker):
|
|||
) -> "VariableTracker":
|
||||
if name == "__len__":
|
||||
return self.dv_dict.call_method(tx, name, args, kwargs)
|
||||
elif name == "__iter__":
|
||||
return ListIteratorVariable(
|
||||
self.view_items_vt, mutation_type=ValueMutationNew()
|
||||
)
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -252,6 +252,26 @@ class IteratorVariable(VariableTracker):
|
|||
def has_force_unpack_var_sequence(self, tx) -> bool:
|
||||
return True
|
||||
|
||||
def call_obj_hasattr(
|
||||
self, tx: "InstructionTranslator", name: str
|
||||
) -> "VariableTracker":
|
||||
if name == "__iter__" or name == "__next__":
|
||||
return variables.ConstantVariable.create(True)
|
||||
super().call_obj_hasattr(tx, name)
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
name,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
if name == "__iter__":
|
||||
return self
|
||||
elif name == "__next__":
|
||||
return self.next_variable(tx)
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
|
||||
class ObjectIteratorVariable(IteratorVariable):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -294,6 +294,10 @@ class BaseListVariable(VariableTracker):
|
|||
[variables.BuiltinVariable(cmp_name_to_op_mapping[name]), left, right],
|
||||
{},
|
||||
)
|
||||
elif name == "__iter__":
|
||||
return ListIteratorVariable(
|
||||
list(self.items), mutation_type=ValueMutationNew()
|
||||
)
|
||||
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
|
|
@ -472,9 +476,9 @@ class RangeVariable(BaseListVariable):
|
|||
def call_obj_hasattr(
|
||||
self, tx: "InstructionTranslator", name: str
|
||||
) -> "VariableTracker":
|
||||
if self.python_type() is not range:
|
||||
return super().call_obj_hasattr(tx, name)
|
||||
return variables.ConstantVariable.create(hasattr(range(0), name))
|
||||
if self.python_type() is range:
|
||||
return variables.ConstantVariable.create(name in range.__dict__)
|
||||
return super().call_obj_hasattr(tx, name)
|
||||
|
||||
def range_equals(self, other: "RangeVariable"):
|
||||
r0, r1 = self, other
|
||||
|
|
@ -1064,6 +1068,13 @@ class DequeVariable(CommonListMethodsVariable):
|
|||
self.items[:] = self.items[slice_within_maxlen]
|
||||
return result
|
||||
|
||||
def call_obj_hasattr(
|
||||
self, tx: "InstructionTranslator", name: str
|
||||
) -> "VariableTracker":
|
||||
if self.python_type() is collections.deque:
|
||||
return variables.ConstantVariable.create(name in collections.deque.__dict__)
|
||||
return super().call_obj_hasattr(tx, name)
|
||||
|
||||
|
||||
class TupleVariable(BaseListVariable):
|
||||
def python_type(self):
|
||||
|
|
|
|||
|
|
@ -796,6 +796,10 @@ class NNModuleVariable(VariableTracker):
|
|||
f"{len(args)} args and {len(kwargs)} kwargs",
|
||||
)
|
||||
return ConstantVariable.create(len(module))
|
||||
elif name == "__iter__":
|
||||
return ListIteratorVariable(
|
||||
self.unpack_var_sequence(tx), mutation_type=ValueMutationNew()
|
||||
)
|
||||
elif (
|
||||
name == "__contains__"
|
||||
and isinstance(module, (torch.nn.ModuleDict, torch.nn.ParameterDict))
|
||||
|
|
|
|||
|
|
@ -66,9 +66,9 @@ from ..utils import (
|
|||
set_example_value,
|
||||
tensortype_to_dtype,
|
||||
)
|
||||
from .base import AttributeMutationNew, VariableTracker
|
||||
from .base import AttributeMutationNew, ValueMutationNew, VariableTracker
|
||||
from .constant import ConstantVariable
|
||||
from .lists import SizeVariable
|
||||
from .lists import ListIteratorVariable, SizeVariable
|
||||
from .user_defined import UserDefinedClassVariable
|
||||
|
||||
|
||||
|
|
@ -427,7 +427,7 @@ class TensorVariable(VariableTracker):
|
|||
# Today, var_getattr returns GetAttrVariable for both non-existent
|
||||
# attributes and existing attributes. This is a bug and requires more
|
||||
# deep dive.
|
||||
if name in ("size", "stride"):
|
||||
if name in ("size", "stride", "__iter__"):
|
||||
return ConstantVariable(True)
|
||||
|
||||
try:
|
||||
|
|
@ -1079,6 +1079,14 @@ class TensorVariable(VariableTracker):
|
|||
tx = InstructionTranslator.current_tx()
|
||||
return self.call_method(tx, "size", [ConstantVariable.create(0)], {})
|
||||
|
||||
def method___iter__(self):
|
||||
from ..symbolic_convert import InstructionTranslator
|
||||
|
||||
tx = InstructionTranslator.current_tx()
|
||||
return ListIteratorVariable(
|
||||
self.unpack_var_sequence(tx), mutation_type=ValueMutationNew()
|
||||
)
|
||||
|
||||
def method_addcmul_(self, tensor1, tensor2, *, value=None):
|
||||
from ..symbolic_convert import InstructionTranslator
|
||||
|
||||
|
|
@ -1612,7 +1620,7 @@ class NumpyNdarrayVariable(TensorVariable):
|
|||
),
|
||||
hints=[*graph_break_hints.FUNDAMENTAL],
|
||||
)
|
||||
if name in ["__len__", "size", "tolist"]:
|
||||
if name in ["__len__", "size", "tolist", "__iter__"]:
|
||||
# delegate back to TensorVariable
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
if name in ("tostring", "tobytes", "__delattr__"):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user