mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[dynamo] Fix ListIterator tracking mutations to original list (#166350)
Currently ListIteratorVariable copies the underlying list, which prevents it from seeing mutations to the original list. Remove the copy to match cpython behavior. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166350 Approved by: https://github.com/guilhermeleobas ghstack dependencies: #166349, #162768
This commit is contained in:
parent
8101fd46d4
commit
e0604d3170
|
|
@ -295,9 +295,7 @@ class BaseListVariable(VariableTracker):
|
|||
{},
|
||||
)
|
||||
elif name == "__iter__":
|
||||
return ListIteratorVariable(
|
||||
list(self.items), mutation_type=ValueMutationNew()
|
||||
)
|
||||
return ListIteratorVariable(self.items, mutation_type=ValueMutationNew())
|
||||
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
|
|
@ -1595,6 +1593,7 @@ class ListIteratorVariable(IteratorVariable):
|
|||
# assert all(isinstance(x, VariableTracker) for x in items)
|
||||
self.items = items
|
||||
self.index = index
|
||||
self.is_exhausted = False
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(length={len(self.items)}, index={repr(self.index)})"
|
||||
|
|
@ -1602,7 +1601,8 @@ class ListIteratorVariable(IteratorVariable):
|
|||
def next_variable(self, tx):
|
||||
assert self.is_mutable()
|
||||
old_index = self.index
|
||||
if old_index >= len(self.items):
|
||||
if old_index >= len(self.items) or self.is_exhausted:
|
||||
self.is_exhausted = True
|
||||
raise_observed_exception(StopIteration, tx)
|
||||
|
||||
tx.output.side_effects.mutation(self)
|
||||
|
|
@ -1624,15 +1624,19 @@ class ListIteratorVariable(IteratorVariable):
|
|||
return True
|
||||
|
||||
def unpack_var_sequence(self, tx):
|
||||
r = list(self.items[self.index :])
|
||||
self.index = len(self.items)
|
||||
return r
|
||||
if self.is_exhausted:
|
||||
return []
|
||||
self.is_exhausted = True
|
||||
return list(self.items[self.index :])
|
||||
|
||||
def force_unpack_var_sequence(self, tx) -> list[VariableTracker]:
|
||||
return self.unpack_var_sequence(tx)
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
remaining_items = self.items[self.index :]
|
||||
if not self.is_exhausted:
|
||||
remaining_items = self.items[self.index :]
|
||||
else:
|
||||
remaining_items = []
|
||||
codegen.foreach(remaining_items)
|
||||
codegen.extend_output(
|
||||
[
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user