mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111725 Approved by: https://github.com/voznesenskym ghstack dependencies: #111306, #111415
101 lines
3.2 KiB
Python
101 lines
3.2 KiB
Python
# Cap on how many items CycleIteratorVariable buffers in ``saved``: if more
# than this many are already buffered when another item is pulled from the
# wrapped iterator, tracing bails out via ``unimplemented``.
MAX_CYCLE = 3_000
|
|
|
|
from typing import List, Optional
|
|
|
|
from ..exc import unimplemented
|
|
|
|
from .base import VariableTracker
|
|
from .constant import ConstantVariable
|
|
|
|
|
|
class IteratorVariable(VariableTracker):
    """Base class for VariableTrackers that model Python iterators.

    Subclasses override :meth:`next_variables`, which returns a pair
    ``(item, next_iterator)``.
    """

    def __init__(self, **kwargs) -> None:
        # No extra state at this level; everything is forwarded to the base.
        super().__init__(**kwargs)

    def next_variables(self, tx):
        """Produce ``(item, next_iterator)``; must be overridden by subclasses."""
        # The base class has no iteration logic of its own, so it reports the
        # call through ``unimplemented`` (imported from ``..exc``).
        unimplemented("abstract method, must implement")
|
|
|
|
|
|
class RepeatIteratorVariable(IteratorVariable):
    """Iterator that yields the same tracked value on every step.

    Only ``item`` is stored; this class does not track a repetition count.
    """

    def __init__(self, item: VariableTracker, **kwargs):
        super().__init__(**kwargs)
        # The value handed out on every call to next_variables.
        self.item = item

    def next_variables(self, tx):
        """Return ``(clone_of_item, self)``.

        The iterator has no state to advance, so ``self`` is handed back
        unchanged; only the yielded ``item`` is cloned.
        """
        yielded = self.item.clone()
        return yielded, self
|
|
|
|
|
|
class CountIteratorVariable(IteratorVariable):
    """Counting iterator in the style of ``itertools.count(start, step)``.

    ``item`` holds the value the next call returns and ``step`` the increment.
    Arguments that are not already VariableTrackers are wrapped with
    ``ConstantVariable.create``.
    """

    def __init__(self, item: int = 0, step: int = 1, **kwargs):
        super().__init__(**kwargs)
        # Normalize both arguments to VariableTrackers, item first, then step.
        start_var, step_var = (
            value
            if isinstance(value, VariableTracker)
            else ConstantVariable.create(value)
            for value in (item, step)
        )
        self.item = start_var
        self.step = step_var

    def next_variables(self, tx):
        """Return ``(current_item, advanced_iterator)``.

        The advanced iterator is a clone whose ``item`` is ``item + step``,
        computed through the tracker's ``__add__`` method; it is substituted
        for ``self`` in ``tx`` via ``replace_all``.
        """
        assert self.mutable_local
        incremented = self.item.call_method(tx, "__add__", [self.step], {})
        advanced = self.clone(item=incremented)
        tx.replace_all(self, advanced)
        return self.item, advanced
|
|
|
|
|
|
class CycleIteratorVariable(IteratorVariable):
    """Tracks an ``itertools.cycle(iterator)``-style object.

    Items are pulled from the wrapped ``iterator`` one step ahead of when they
    are returned, and each pulled item is appended to ``saved``. Once the
    wrapped iterator raises ``StopIteration`` it is dropped (``iterator=None``)
    and ``saved`` is replayed in order, wrapping around indefinitely.

    Attributes:
        iterator: the wrapped iterator still being consumed, or ``None`` once
            it has been exhausted.
        saved: every item pulled from ``iterator`` so far.
        saved_index: index in ``saved`` of the item that the next replay step
            stages into ``item``.
        item: the item the next call to :meth:`next_variables` returns, or
            ``None`` before the first item has been pulled.
    """

    def __init__(
        self,
        iterator: IteratorVariable,
        saved: Optional[List[VariableTracker]] = None,
        saved_index: int = 0,
        item: Optional[VariableTracker] = None,
        **kwargs,
    ):
        # Avoid a shared mutable default: each instance gets its own list.
        if saved is None:
            saved = []
        super().__init__(**kwargs)
        self.iterator = iterator
        self.saved = saved
        self.saved_index = saved_index
        self.item = item

    def next_variables(self, tx):
        """Return ``(item, next_cycle_iterator)`` for one ``next()`` call.

        ``self`` is not mutated. A clone carrying the advanced state is built
        and substituted for ``self`` via ``tx.replace_all``.

        Raises:
            StopIteration: if there is nothing to cycle through (the wrapped
                iterator is gone and ``saved`` is empty).
        """
        assert self.mutable_local

        if self.iterator is not None:
            # Keep only the inner ``next`` inside the ``try``. A StopIteration
            # raised anywhere else (e.g. by the recursive call below, made
            # after ``self`` was already replaced) must not be mistaken for
            # exhaustion of the wrapped iterator.
            try:
                new_item, next_inner_iter = self.iterator.next_variables(tx)
            except StopIteration:
                # Wrapped iterator exhausted: drop it and continue by
                # replaying ``saved``. The recursive call does its own
                # replace_all as well; this one is kept for safety.
                next_iter = self.clone(iterator=None)
                tx.replace_all(self, next_iter)
                return next_iter.next_variables(tx)

            tx.replace_all(self.iterator, next_inner_iter)
            if len(self.saved) > MAX_CYCLE:
                unimplemented(
                    "input iterator to itertools.cycle has too many items"
                )
            next_iter = self.clone(
                iterator=next_inner_iter,
                saved=self.saved + [new_item],
                item=new_item,
            )
            tx.replace_all(self, next_iter)
            if self.item is None:
                # First call: nothing was staged yet, so advance once more;
                # that call returns the item just fetched.
                return next_iter.next_variables(tx)
            return self.item, next_iter

        if self.saved:
            # Replay phase: return the staged item and stage the next saved
            # one, wrapping around at the end of ``saved``.
            next_iter = self.clone(
                saved_index=(self.saved_index + 1) % len(self.saved),
                item=self.saved[self.saved_index],
            )
            tx.replace_all(self, next_iter)
            return self.item, next_iter

        # No wrapped iterator left and nothing saved: the cycle is empty.
        raise StopIteration
|