MAX_CYCLE = 3000 from typing import List, Optional from ..exc import unimplemented from .base import VariableTracker from .constant import ConstantVariable class IteratorVariable(VariableTracker): def __init__(self, **kwargs): super().__init__(**kwargs) def next_variables(self, tx): unimplemented("abstract method, must implement") class RepeatIteratorVariable(IteratorVariable): def __init__(self, item: VariableTracker, **kwargs): super().__init__(**kwargs) self.item = item # Repeat needs no mutation, clone self def next_variables(self, tx): return self.item.clone(), self class CountIteratorVariable(IteratorVariable): def __init__(self, item: int = 0, step: int = 1, **kwargs): super().__init__(**kwargs) if not isinstance(item, VariableTracker): item = ConstantVariable.create(item) if not isinstance(step, VariableTracker): step = ConstantVariable.create(step) self.item = item self.step = step def next_variables(self, tx): assert self.mutable_local next_item = self.item.call_method(tx, "__add__", [self.step], {}) next_iter = self.clone(item=next_item) tx.replace_all(self, next_iter) return self.item, next_iter class CycleIteratorVariable(IteratorVariable): def __init__( self, iterator: IteratorVariable, saved: List[VariableTracker] = None, saved_index: int = 0, item: Optional[VariableTracker] = None, **kwargs, ): if saved is None: saved = [] super().__init__(**kwargs) self.iterator = iterator self.saved = saved self.saved_index = saved_index self.item = item def next_variables(self, tx): assert self.mutable_local if self.iterator is not None: try: new_item, next_inner_iter = self.iterator.next_variables(tx) tx.replace_all(self.iterator, next_inner_iter) if len(self.saved) > MAX_CYCLE: unimplemented( "input iterator to itertools.cycle has too many items" ) next_iter = self.clone( iterator=next_inner_iter, saved=self.saved + [new_item], item=new_item, ) tx.replace_all(self, next_iter) if self.item is None: return next_iter.next_variables(tx) return self.item, next_iter except StopIteration: next_iter = self.clone(iterator=None) # this is redundant as next_iter will do the same # but we do it anyway for safety tx.replace_all(self, next_iter) return next_iter.next_variables(tx) elif len(self.saved) > 0: next_iter = self.clone( saved_index=(self.saved_index + 1) % len(self.saved), item=self.saved[self.saved_index], ) tx.replace_all(self, next_iter) return self.item, next_iter else: raise StopIteration return self.item, next_iter