mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
1. Removes calls to `replace_all` and `clone` and makes VTs mutable. 2. Properly handles Tuple Iterator mutation. Previously TupleIterator variables would only be properly reconstructed if they were advanced at least once in a frame. On calls to `next`, the source information would be lost (due to constructing a new iterator without using builder), which would ensure that during codegen the variable would be reconstructed from scratch. Now that VTs are mutated, the source is never lost, so we need to properly track mutation and handle it by replaying calls to `next` at the end of the modified bytecode. 3. Added test for checking iadd side effects, this was missing in our unit test coverage. 4. Fixed two incorrect sources, DelayGraphBreakVariable, and UserMethodVariable both relied on setting the source to AttrSource(parent, name) at the callsite of `var_getattr`. 5. Fixed a bug in inplace adding for lists, it would set the resulting VariableTracker's source to `None` which would utilize a different reconstruct path in codegen. Now this is handled explicitly by reconstructing vars when allow_cache=`False`, so that during side effect replay, the mutated var is correctly updated. In subsequent PRs: * Refactoring side effect tracking to be significantly simpler (I think we only need an `is_modified` flag) * Refactor `next_variables` iterator to match the signature of `next` * Remove all references to `options` in the code * Refactor VTs representing mutable collections to implement their own mutation update handling * Remove clone and/or make it specific to lists for creating slices * Add mutation tracking/replay for sets * Add mutation tracking/replay for iter.py * Removing setting source in builder (it's set at the top level after a var is returned) Pull Request resolved: https://github.com/pytorch/pytorch/pull/113725 Approved by: https://github.com/jansel
89 lines
2.7 KiB
Python
89 lines
2.7 KiB
Python
# Threshold on how many items CycleIteratorVariable may have buffered from its
# input iterator; once the buffer holds more than this many items, the next
# fetch calls `unimplemented` instead of buffering further.
MAX_CYCLE = 3000
|
|
|
|
from typing import List, Optional
|
|
|
|
from ..exc import unimplemented
|
|
|
|
from .base import VariableTracker
|
|
from .constant import ConstantVariable
|
|
|
|
|
|
class IteratorVariable(VariableTracker):
    """Common base for variable trackers that stand in for iterators.

    Concrete subclasses in this file override ``next_variables`` and return a
    ``(item, iterator)`` pair from it.
    """

    def __init__(self, **kwargs):
        # No iterator-specific state lives on the base class; every keyword
        # argument is forwarded unchanged to VariableTracker.
        super().__init__(**kwargs)

    def next_variables(self, tx):
        """Advance the iterator by one step; subclasses must override.

        The base implementation only reports, via ``unimplemented``, that the
        method was not overridden.
        """
        unimplemented("abstract method, must implement")
|
|
|
|
|
|
class RepeatIteratorVariable(IteratorVariable):
    """Iterator tracker that hands back the same item on every step."""

    def __init__(self, item: VariableTracker, **kwargs):
        super().__init__(**kwargs)
        # The single value produced by every call to next_variables.
        self.item = item

    def next_variables(self, tx):
        """Return ``(item, self)`` without touching any state.

        ``tx`` is accepted for signature parity with the other iterator
        trackers but is not used here.
        """
        # Stepping changes nothing on this object, so no mutation is recorded
        # and the very same tracker is returned as the "next" iterator.
        repeated = self.item
        return repeated, self
|
|
|
|
|
|
class CountIteratorVariable(IteratorVariable):
    """Iterator tracker modelling a counter that starts at ``item`` and grows by ``step``.

    ``item`` and ``step`` may be given either as plain values or as
    VariableTrackers; plain values are wrapped with ``ConstantVariable.create``.
    """

    def __init__(self, item: int = 0, step: int = 1, **kwargs):
        super().__init__(**kwargs)
        # Normalise both operands to VariableTrackers so next_variables can
        # always dispatch the addition through call_method.
        if not isinstance(item, VariableTracker):
            item = ConstantVariable.create(item)
        if not isinstance(step, VariableTracker):
            step = ConstantVariable.create(step)
        # The value that the next call to next_variables will yield.
        self.item = item
        self.step = step

    def next_variables(self, tx):
        """Yield the current count and advance the counter in place.

        Returns:
            ``(current_item, self)`` where ``current_item`` is the value held
            *before* the increment, so the first value yielded is the start
            value, matching ``itertools.count``.
        """
        assert self.mutable_local
        # Register the in-place change with the side-effect tracker before
        # mutating this object.
        tx.output.side_effects.mutation(self)
        current_item = self.item
        # Advance the stored value for the following step, but hand back the
        # pre-increment value: returning the incremented one would skip the
        # start value.
        self.item = current_item.call_method(tx, "__add__", [self.step], {})
        return current_item, self
|
|
|
|
|
|
class CycleIteratorVariable(IteratorVariable):
    """Iterator tracker modelling ``itertools.cycle`` over another iterator tracker.

    Items are first pulled from ``iterator`` and buffered in ``saved``; once
    the input is exhausted, the buffered items are replayed in order,
    wrapping around indefinitely.
    """

    def __init__(
        self,
        iterator: IteratorVariable,
        saved: Optional[List[VariableTracker]] = None,
        saved_index: int = 0,
        item: Optional[VariableTracker] = None,
        **kwargs,
    ):
        # Build a fresh list per instance rather than sharing a default.
        if saved is None:
            saved = []
        super().__init__(**kwargs)
        # Input iterator tracker; set to None once it has been exhausted.
        self.iterator = iterator
        # Items seen so far from the input, in the order they were produced.
        self.saved = saved
        # Position in `saved` of the next item to replay.
        self.saved_index = saved_index
        # Most recently yielded item.
        self.item = item

    def next_variables(self, tx):
        """Yield the next item of the cycle and update this tracker in place.

        Returns:
            ``(item, self)``.

        Raises:
            StopIteration: if the input iterator is exhausted and produced no
                items at all.
        """
        assert self.mutable_local

        if self.iterator is not None:
            # Only the inner fetch is expected to signal exhaustion, so keep
            # the try body limited to that call.
            try:
                new_item, _ = self.iterator.next_variables(tx)
            except StopIteration:
                # Input is drained: switch permanently to replaying `saved`.
                self.iterator = None
                return self.next_variables(tx)

            if len(self.saved) > MAX_CYCLE:
                unimplemented(
                    "input iterator to itertools.cycle has too many items"
                )
            tx.output.side_effects.mutation(self)
            self.saved.append(new_item)
            self.item = new_item
            # NOTE(review): retained from the original; only triggers if the
            # inner iterator hands back a bare None — confirm whether this
            # can still happen.
            if self.item is None:
                return self.next_variables(tx)
            return self.item, self

        if len(self.saved) > 0:
            tx.output.side_effects.mutation(self)
            # Replay the buffered item at the current position, then advance
            # with wrap-around. Without refreshing `item` here, every replayed
            # step would repeat the last item taken from the input.
            self.item = self.saved[self.saved_index]
            self.saved_index = (self.saved_index + 1) % len(self.saved)
            return self.item, self

        # Exhausted input that never produced anything: the cycle is empty.
        raise StopIteration
|