diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index 69145ce86d8..9b465726754 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -14,15 +14,36 @@ import torch def index(iterator, item, start=0, end=None): - import itertools - - for i, elem in itertools.islice(enumerate(iterator), start, end): + for i, elem in islice(enumerate(iterator), start, end): if item == elem: return i # This will not run in dynamo raise ValueError(f"{item} is not in {type(iterator)}") +def islice(iterator, start=0, end=None, step=1): + if start < 0 or (end is not None and end < 0) or step < 0: + raise ValueError("Indices must be non-negative") + if step == 0: + raise ValueError("Step cannot be 0") + + it = iter(iterator) + + for _ in range(start): + next(it) + + if end is None: + for i, element in enumerate(it): + if i % step == 0: + yield element + else: + for i, element in enumerate(it): + if i % step == 0 and i + start < end - start: + yield element + elif i + start >= end - start: + break + + def repeat(item, count): for i in range(count): yield item diff --git a/torch/_dynamo/polyfills/itertools.py b/torch/_dynamo/polyfills/itertools.py index a829d2e9b88..802a62a82c8 100644 --- a/torch/_dynamo/polyfills/itertools.py +++ b/torch/_dynamo/polyfills/itertools.py @@ -13,7 +13,6 @@ from ..decorators import substitute_in_graph __all__ = [ "chain", "chain_from_iterable", - "islice", "tee", ] @@ -36,35 +35,6 @@ def chain_from_iterable(iterable: Iterable[Iterable[_T]], /) -> Iterator[_T]: chain.from_iterable = chain_from_iterable # type: ignore[method-assign] -# Reference: https://docs.python.org/3/library/itertools.html#itertools.islice -@substitute_in_graph(itertools.islice, is_embedded_type=True) # type: ignore[arg-type] -def islice(iterable: Iterable[_T], /, *args: int | None) -> Iterator[_T]: - s = slice(*args) - start = 0 if s.start is None else s.start - stop = s.stop - step = 1 if s.step is None else s.step - if start < 0 or (stop is not None and stop < 0) or step <= 0: - raise ValueError( - "Indices for islice() must be None or an integer: 0 <= x <= sys.maxsize.", - ) - - if stop is None: - # TODO: use indices = itertools.count() and merge implementation with the else branch - # when we support infinite iterators - next_i = start - for i, element in enumerate(iterable): - if i == next_i: - yield element - next_i += step - else: - indices = range(max(start, stop)) - next_i = start - for i, element in zip(indices, iterable): - if i == next_i: - yield element - next_i += step - - # Reference: https://docs.python.org/3/library/itertools.html#itertools.tee @substitute_in_graph(itertools.tee) def tee(iterable: Iterable[_T], n: int = 2, /) -> tuple[Iterator[_T], ...]: diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 5535273b7dc..29ad4847f27 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -12,6 +12,7 @@ import enum import functools import importlib import inspect +import itertools import linecache import logging import multiprocessing @@ -2992,6 +2993,7 @@ def _builtin_function_ids() -> Dict[int, str]: if not k.startswith("_") and callable(v) } ) + rv.update({id(v): f"itertools.{v.__name__}" for v in (itertools.islice,)}) rv.update( { id(cast): "typing.cast", diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 919daf6fbd2..279981465b2 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1873,6 +1873,15 @@ class BuiltinVariable(VariableTracker): ) return variables.ListVariable(items) + def call_islice(self, tx: "InstructionTranslator", iterable, *args): + if iterable.has_unpack_var_sequence(tx) and all( + x.is_python_constant() for x in args + ): + const_args = [x.as_python_constant() for x in args] + items = iterable.unpack_var_sequence(tx) + items = list(itertools.islice(items, *const_args)) + return variables.TupleVariable(items) + # neg is a constant fold function, so we only get here if constant fold is not valid def call_neg(self, tx: "InstructionTranslator", a): if isinstance(a, SymNodeVariable): diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index 34c2354026a..ed3ae786634 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -166,6 +166,12 @@ class ItertoolsVariable(VariableTracker): from_exc=e, ) return variables.ListIteratorVariable(result, mutable_local=MutableLocal()) + elif self.value is itertools.islice: + from .builder import SourcelessBuilder + + return tx.inline_user_function_return( + SourcelessBuilder.create(tx, polyfills.islice), args, kwargs + ) elif self.value is itertools.repeat: if len(args) < 2: return variables.RepeatIteratorVariable(