[dynamo] Fix ListIterator tracking mutations to original list (#166350)

Currently ListIteratorVariable copies the underlying list, which prevents it
from seeing mutations to the original list.  Remove the copy to match cpython behavior.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166350
Approved by: https://github.com/guilhermeleobas
ghstack dependencies: #166349, #162768
This commit is contained in:
Rob Timpe 2025-10-28 20:52:20 +00:00 committed by PyTorch MergeBot
parent 8101fd46d4
commit e0604d3170
2 changed files with 12 additions and 8 deletions

View File

@ -295,9 +295,7 @@ class BaseListVariable(VariableTracker):
{}, {},
) )
elif name == "__iter__": elif name == "__iter__":
return ListIteratorVariable( return ListIteratorVariable(self.items, mutation_type=ValueMutationNew())
list(self.items), mutation_type=ValueMutationNew()
)
return super().call_method(tx, name, args, kwargs) return super().call_method(tx, name, args, kwargs)
@ -1595,6 +1593,7 @@ class ListIteratorVariable(IteratorVariable):
# assert all(isinstance(x, VariableTracker) for x in items) # assert all(isinstance(x, VariableTracker) for x in items)
self.items = items self.items = items
self.index = index self.index = index
self.is_exhausted = False
def __repr__(self) -> str: def __repr__(self) -> str:
return f"{self.__class__.__name__}(length={len(self.items)}, index={repr(self.index)})" return f"{self.__class__.__name__}(length={len(self.items)}, index={repr(self.index)})"
@ -1602,7 +1601,8 @@ class ListIteratorVariable(IteratorVariable):
def next_variable(self, tx): def next_variable(self, tx):
assert self.is_mutable() assert self.is_mutable()
old_index = self.index old_index = self.index
if old_index >= len(self.items): if old_index >= len(self.items) or self.is_exhausted:
self.is_exhausted = True
raise_observed_exception(StopIteration, tx) raise_observed_exception(StopIteration, tx)
tx.output.side_effects.mutation(self) tx.output.side_effects.mutation(self)
@ -1624,15 +1624,19 @@ class ListIteratorVariable(IteratorVariable):
return True return True
def unpack_var_sequence(self, tx): def unpack_var_sequence(self, tx):
r = list(self.items[self.index :]) if self.is_exhausted:
self.index = len(self.items) return []
return r self.is_exhausted = True
return list(self.items[self.index :])
def force_unpack_var_sequence(self, tx) -> list[VariableTracker]: def force_unpack_var_sequence(self, tx) -> list[VariableTracker]:
return self.unpack_var_sequence(tx) return self.unpack_var_sequence(tx)
def reconstruct(self, codegen: "PyCodegen") -> None: def reconstruct(self, codegen: "PyCodegen") -> None:
remaining_items = self.items[self.index :] if not self.is_exhausted:
remaining_items = self.items[self.index :]
else:
remaining_items = []
codegen.foreach(remaining_items) codegen.foreach(remaining_items)
codegen.extend_output( codegen.extend_output(
[ [