mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo] Fix ListIterator tracking mutations to original list (#166350)
Currently ListIteratorVariable copies the underlying list, which prevents it from seeing mutations to the original list. Remove the copy to match cpython behavior. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166350 Approved by: https://github.com/guilhermeleobas ghstack dependencies: #166349, #162768
This commit is contained in:
parent
8101fd46d4
commit
e0604d3170
|
|
@ -295,9 +295,7 @@ class BaseListVariable(VariableTracker):
|
||||||
{},
|
{},
|
||||||
)
|
)
|
||||||
elif name == "__iter__":
|
elif name == "__iter__":
|
||||||
return ListIteratorVariable(
|
return ListIteratorVariable(self.items, mutation_type=ValueMutationNew())
|
||||||
list(self.items), mutation_type=ValueMutationNew()
|
|
||||||
)
|
|
||||||
|
|
||||||
return super().call_method(tx, name, args, kwargs)
|
return super().call_method(tx, name, args, kwargs)
|
||||||
|
|
||||||
|
|
@ -1595,6 +1593,7 @@ class ListIteratorVariable(IteratorVariable):
|
||||||
# assert all(isinstance(x, VariableTracker) for x in items)
|
# assert all(isinstance(x, VariableTracker) for x in items)
|
||||||
self.items = items
|
self.items = items
|
||||||
self.index = index
|
self.index = index
|
||||||
|
self.is_exhausted = False
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"{self.__class__.__name__}(length={len(self.items)}, index={repr(self.index)})"
|
return f"{self.__class__.__name__}(length={len(self.items)}, index={repr(self.index)})"
|
||||||
|
|
@ -1602,7 +1601,8 @@ class ListIteratorVariable(IteratorVariable):
|
||||||
def next_variable(self, tx):
|
def next_variable(self, tx):
|
||||||
assert self.is_mutable()
|
assert self.is_mutable()
|
||||||
old_index = self.index
|
old_index = self.index
|
||||||
if old_index >= len(self.items):
|
if old_index >= len(self.items) or self.is_exhausted:
|
||||||
|
self.is_exhausted = True
|
||||||
raise_observed_exception(StopIteration, tx)
|
raise_observed_exception(StopIteration, tx)
|
||||||
|
|
||||||
tx.output.side_effects.mutation(self)
|
tx.output.side_effects.mutation(self)
|
||||||
|
|
@ -1624,15 +1624,19 @@ class ListIteratorVariable(IteratorVariable):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def unpack_var_sequence(self, tx):
|
def unpack_var_sequence(self, tx):
|
||||||
r = list(self.items[self.index :])
|
if self.is_exhausted:
|
||||||
self.index = len(self.items)
|
return []
|
||||||
return r
|
self.is_exhausted = True
|
||||||
|
return list(self.items[self.index :])
|
||||||
|
|
||||||
def force_unpack_var_sequence(self, tx) -> list[VariableTracker]:
|
def force_unpack_var_sequence(self, tx) -> list[VariableTracker]:
|
||||||
return self.unpack_var_sequence(tx)
|
return self.unpack_var_sequence(tx)
|
||||||
|
|
||||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||||
remaining_items = self.items[self.index :]
|
if not self.is_exhausted:
|
||||||
|
remaining_items = self.items[self.index :]
|
||||||
|
else:
|
||||||
|
remaining_items = []
|
||||||
codegen.foreach(remaining_items)
|
codegen.foreach(remaining_items)
|
||||||
codegen.extend_output(
|
codegen.extend_output(
|
||||||
[
|
[
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user