diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index 9b465726754..69145ce86d8 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -14,36 +14,15 @@ import torch def index(iterator, item, start=0, end=None): - for i, elem in islice(enumerate(iterator), start, end): + import itertools + + for i, elem in itertools.islice(enumerate(iterator), start, end): if item == elem: return i # This will not run in dynamo raise ValueError(f"{item} is not in {type(iterator)}") -def islice(iterator, start=0, end=None, step=1): - if start < 0 or (end is not None and end < 0) or step < 0: - raise ValueError("Indices must be non-negative") - if step == 0: - raise ValueError("Step cannot be 0") - - it = iter(iterator) - - for _ in range(start): - next(it) - - if end is None: - for i, element in enumerate(it): - if i % step == 0: - yield element - else: - for i, element in enumerate(it): - if i % step == 0 and i + start < end - start: - yield element - elif i + start >= end - start: - break - - def repeat(item, count): for i in range(count): yield item diff --git a/torch/_dynamo/polyfills/itertools.py b/torch/_dynamo/polyfills/itertools.py index 802a62a82c8..a829d2e9b88 100644 --- a/torch/_dynamo/polyfills/itertools.py +++ b/torch/_dynamo/polyfills/itertools.py @@ -13,6 +13,7 @@ from ..decorators import substitute_in_graph __all__ = [ "chain", "chain_from_iterable", + "islice", "tee", ] @@ -35,6 +36,35 @@ def chain_from_iterable(iterable: Iterable[Iterable[_T]], /) -> Iterator[_T]: chain.from_iterable = chain_from_iterable # type: ignore[method-assign] +# Reference: https://docs.python.org/3/library/itertools.html#itertools.islice +@substitute_in_graph(itertools.islice, is_embedded_type=True) # type: ignore[arg-type] +def islice(iterable: Iterable[_T], /, *args: int | None) -> Iterator[_T]: + s = slice(*args) + start = 0 if s.start is None else s.start + stop = s.stop + step = 1 if s.step is None else s.step + if start < 0 or (stop is not None and stop < 0) or step <= 0: + raise ValueError( + "Indices for islice() must be None or an integer: 0 <= x <= sys.maxsize.", + ) + + if stop is None: + # TODO: use indices = itertools.count() and merge implementation with the else branch + # when we support infinite iterators + next_i = start + for i, element in enumerate(iterable): + if i == next_i: + yield element + next_i += step + else: + indices = range(max(start, stop)) + next_i = start + for i, element in zip(indices, iterable): + if i == next_i: + yield element + next_i += step + + # Reference: https://docs.python.org/3/library/itertools.html#itertools.tee @substitute_in_graph(itertools.tee) def tee(iterable: Iterable[_T], n: int = 2, /) -> tuple[Iterator[_T], ...]: diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 464b411a8b9..e8692fe4789 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -12,7 +12,6 @@ import enum import functools import importlib import inspect -import itertools import linecache import logging import multiprocessing @@ -2993,7 +2992,6 @@ def _builtin_function_ids() -> Dict[int, str]: if not k.startswith("_") and callable(v) } ) - rv.update({id(v): f"itertools.{v.__name__}" for v in (itertools.islice,)}) rv.update( { id(cast): "typing.cast", diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 279981465b2..919daf6fbd2 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1873,15 +1873,6 @@ class BuiltinVariable(VariableTracker): ) return variables.ListVariable(items) - def call_islice(self, tx: "InstructionTranslator", iterable, *args): - if iterable.has_unpack_var_sequence(tx) and all( - x.is_python_constant() for x in args - ): - const_args = [x.as_python_constant() for x in args] - items = iterable.unpack_var_sequence(tx) - items = list(itertools.islice(items, *const_args)) - return variables.TupleVariable(items) - # neg is a constant fold function, so we only get here if constant fold is not valid def call_neg(self, tx: "InstructionTranslator", a): if isinstance(a, SymNodeVariable): diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index ed3ae786634..34c2354026a 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -166,12 +166,6 @@ class ItertoolsVariable(VariableTracker): from_exc=e, ) return variables.ListIteratorVariable(result, mutable_local=MutableLocal()) - elif self.value is itertools.islice: - from .builder import SourcelessBuilder - - return tx.inline_user_function_return( - SourcelessBuilder.create(tx, polyfills.islice), args, kwargs - ) elif self.value is itertools.repeat: if len(args) < 2: return variables.RepeatIteratorVariable(