pytorch/torch/_dynamo/variables/iter.py
Michael Lazos fbeca60b1f Remove replace_all and make VTs mutable (#113725)
1. Removes calls to `replace_all` and `clone` and makes VTs mutable.
2. Properly handles tuple iterator mutation. Previously, TupleIterator variables were only reconstructed correctly if they were advanced at least once in a frame. On calls to `next`, the source information was lost (because a new iterator was constructed without using builder), which meant that during codegen the variable was reconstructed from scratch. Now that VTs are mutated, the source is never lost, so we need to track mutation properly and handle it by replaying calls to `next` at the end of the modified bytecode.
3. Added a test checking `iadd` side effects, which was missing from our unit test coverage.
4. Fixed two incorrect sources: `DelayGraphBreakVariable` and `UserMethodVariable` both relied on setting the source to `AttrSource(parent, name)` at the call site of `var_getattr`.
5. Fixed a bug in in-place addition for lists: it set the resulting VariableTracker's source to `None`, which sent it down a different reconstruct path in codegen. This is now handled explicitly by reconstructing vars when `allow_cache=False`, so that during side-effect replay the mutated var is correctly updated.
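The replay strategy in item 2 can be modeled in plain Python. This is a minimal sketch, not Dynamo's code: `TracedTupleIterator` and `replay_next_calls` are hypothetical names, and the "source" is simply a reference to the real iterator. Tracing mutates one object and counts `next` calls instead of building a new iterator, so the source is never lost; afterwards the same number of `next` calls is replayed on the real iterator so its state matches what eager execution would leave behind.

```python
from typing import Any, Iterator, Tuple


class TracedTupleIterator:
    """Hypothetical stand-in for a mutable tuple-iterator VariableTracker."""

    def __init__(self, items: Tuple[Any, ...], source: Iterator[Any]):
        self.items = items      # contents known at trace time
        self.index = 0          # how many times tracing called `next`
        self.source = source    # the real runtime iterator; kept, never dropped
        self.mutated = False    # stands in for side-effect tracking

    def next_variables(self) -> Tuple[Any, "TracedTupleIterator"]:
        if self.index >= len(self.items):
            raise StopIteration
        item = self.items[self.index]
        self.index += 1
        self.mutated = True     # record the mutation instead of rebuilding
        return item, self       # same object back, so the source survives


def replay_next_calls(traced: TracedTupleIterator) -> None:
    """After the traced code runs, advance the real iterator to match."""
    if traced.mutated:
        for _ in range(traced.index):
            next(traced.source)


real = iter((10, 20, 30))
traced = TracedTupleIterator((10, 20, 30), real)
first, traced = traced.next_variables()
second, traced = traced.next_variables()
replay_next_calls(traced)
remaining = list(real)
print(first, second, remaining)  # 10 20 [30]
```

Without the replay step, `real` would still yield all three items even though the traced code consumed two of them.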
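Items 3 and 5 rest on ordinary Python semantics: `+=` on a list mutates the existing object, so every other reference sees the change, whereas `b = b + [...]` builds a new list and rebinds the name. The snippet below is plain Python, not Dynamo code; it shows why the result of an in-place add has to stay tied to the original list's source for side-effect replay to update it.

```python
a = [1, 2]
alias = a
a += [3]        # list.__iadd__ extends in place and returns the same object
print(a is alias, alias)  # True [1, 2, 3]

b = [1, 2]
b_alias = b
b = b + [3]     # list.__add__ builds a new list; b_alias is untouched
print(b is b_alias, b_alias)  # False [1, 2]
```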

In subsequent PRs:
* Refactor side-effect tracking to be significantly simpler (I think we only need an `is_modified` flag)
* Refactor `next_variables` iterator to match the signature of `next`
* Remove all references to `options` in the code
* Refactor VTs representing mutable collections to implement their own mutation update handling
* Remove clone and/or make it specific to lists for creating slices
* Add mutation tracking/replay for sets
* Add mutation tracking/replay for iter.py
* Stop setting the source in builder (it's set at the top level after a var is returned)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113725
Approved by: https://github.com/jansel
2023-12-10 09:31:21 +00:00


MAX_CYCLE = 3000

from typing import List, Optional

from ..exc import unimplemented

from .base import VariableTracker
from .constant import ConstantVariable


class IteratorVariable(VariableTracker):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def next_variables(self, tx):
        unimplemented("abstract method, must implement")


class RepeatIteratorVariable(IteratorVariable):
    def __init__(self, item: VariableTracker, **kwargs):
        super().__init__(**kwargs)
        self.item = item

    # Repeat never changes state: there is no mutation to record, so return self as-is
    def next_variables(self, tx):
        return self.item, self


class CountIteratorVariable(IteratorVariable):
    def __init__(self, item: int = 0, step: int = 1, **kwargs):
        super().__init__(**kwargs)
        if not isinstance(item, VariableTracker):
            item = ConstantVariable.create(item)
        if not isinstance(step, VariableTracker):
            step = ConstantVariable.create(step)
        self.item = item
        self.step = step

    def next_variables(self, tx):
        assert self.mutable_local
        # Like itertools.count: yield the current value, then advance by `step`
        old_item = self.item
        tx.output.side_effects.mutation(self)
        self.item = self.item.call_method(tx, "__add__", [self.step], {})
        return old_item, self


class CycleIteratorVariable(IteratorVariable):
    def __init__(
        self,
        iterator: IteratorVariable,
        saved: Optional[List[VariableTracker]] = None,
        saved_index: int = 0,
        item: Optional[VariableTracker] = None,
        **kwargs,
    ):
        if saved is None:
            saved = []
        super().__init__(**kwargs)
        self.iterator = iterator
        self.saved = saved
        self.saved_index = saved_index
        self.item = item

    def next_variables(self, tx):
        assert self.mutable_local

        if self.iterator is not None:
            # First pass: pull from the wrapped iterator and remember each item
            try:
                new_item, _ = self.iterator.next_variables(tx)
                if len(self.saved) > MAX_CYCLE:
                    unimplemented(
                        "input iterator to itertools.cycle has too many items"
                    )
                tx.output.side_effects.mutation(self)
                self.saved.append(new_item)
                self.item = new_item
                if self.item is None:
                    return self.next_variables(tx)
                return self.item, self
            except StopIteration:
                # Wrapped iterator exhausted: switch to replaying saved items
                self.iterator = None
                return self.next_variables(tx)
        elif len(self.saved) > 0:
            # Later passes: yield saved items in order, wrapping around
            tx.output.side_effects.mutation(self)
            self.item = self.saved[self.saved_index]
            self.saved_index = (self.saved_index + 1) % len(self.saved)
            return self.item, self
        else:
            raise StopIteration
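Each subclass implements `next_variables(tx)`, which returns the produced item together with the iterator to keep using (always `self` now that VTs are mutated in place) and raises `StopIteration` once exhausted. The pure-Python sketch below models that protocol for `repeat` and `count`. `PlainRepeat`, `PlainCount`, and `drive` are hypothetical names, `tx` is dropped, and a `mutations` counter stands in for `tx.output.side_effects.mutation(self)`. The count model follows `itertools.count` semantics (yield the current value, then advance), assumed here to be the behavior the Dynamo variable should reproduce.

```python
import itertools
from typing import Any, List, Tuple


class PlainRepeat:
    """Hypothetical model of RepeatIteratorVariable: state never changes."""

    def __init__(self, item: Any):
        self.item = item

    def next_variables(self) -> Tuple[Any, "PlainRepeat"]:
        return self.item, self  # nothing to record as a mutation


class PlainCount:
    """Hypothetical model of CountIteratorVariable."""

    def __init__(self, item: int = 0, step: int = 1):
        self.item = item
        self.step = step
        self.mutations = 0  # stands in for side_effects.mutation(self)

    def next_variables(self) -> Tuple[int, "PlainCount"]:
        old_item = self.item
        self.mutations += 1
        self.item = self.item + self.step  # advance after capturing the value
        return old_item, self


def drive(it: Any, limit: int) -> List[Any]:
    """Pull up to `limit` items, stopping early on StopIteration."""
    out = []
    for _ in range(limit):
        try:
            value, it = it.next_variables()
        except StopIteration:
            break
        out.append(value)
    return out


repeat_items = drive(PlainRepeat("x"), 3)
counter = PlainCount(0, 2)
count_items = drive(counter, 3)
matches_stdlib = count_items == list(itertools.islice(itertools.count(0, 2), 3))
print(repeat_items, count_items, counter.mutations, matches_stdlib)
# ['x', 'x', 'x'] [0, 2, 4] 3 True
```

The `limit` in `drive` matters because neither `repeat` nor `count` ever raises `StopIteration` on its own.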
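`CycleIteratorVariable` works in two phases: while the wrapped iterator still produces items, each one is returned and saved (up to `MAX_CYCLE`); once it is exhausted, the saved items are replayed in order, wrapping around, and an empty input raises `StopIteration`. Below is a pure-Python sketch of those `itertools.cycle` semantics, assuming plain Python objects in place of VariableTrackers. `PlainCycle` is a hypothetical name, and `NotImplementedError` stands in for `unimplemented(...)`, which causes a graph break in Dynamo.

```python
import itertools
from typing import Any, Iterator, List, Optional, Tuple

MAX_CYCLE = 3000  # mirrors the module-level cap in iter.py


class PlainCycle:
    """Hypothetical model of itertools.cycle with a bounded save buffer."""

    def __init__(self, iterator: Iterator[Any]):
        self.iterator: Optional[Iterator[Any]] = iterator
        self.saved: List[Any] = []  # items seen during the first pass
        self.saved_index = 0        # next position to replay from `saved`

    def next_variables(self) -> Tuple[Any, "PlainCycle"]:
        if self.iterator is not None:
            try:
                item = next(self.iterator)
            except StopIteration:
                # First pass finished: switch to replaying the buffer.
                self.iterator = None
                return self.next_variables()
            if len(self.saved) > MAX_CYCLE:
                # Stand-in for unimplemented(...), i.e. a graph break.
                raise NotImplementedError("input to cycle has too many items")
            self.saved.append(item)
            return item, self
        if self.saved:
            item = self.saved[self.saved_index]
            self.saved_index = (self.saved_index + 1) % len(self.saved)
            return item, self
        raise StopIteration  # empty input: cycle yields nothing


cyc = PlainCycle(iter("ab"))
got = []
for _ in range(5):
    value, cyc = cyc.next_variables()
    got.append(value)

expected = list(itertools.islice(itertools.cycle("ab"), 5))
print(got, got == expected)  # ['a', 'b', 'a', 'b', 'a'] True
```

Saving items during the first pass is what lets the replay phase work even when the wrapped iterator cannot be restarted.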