mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo][itertools] refactor itertools.islice to use polyfill (#133876)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133876 Approved by: https://github.com/jansel ghstack dependencies: #133864, #133894
This commit is contained in:
parent
ec660c383e
commit
eed0d76682
|
|
@ -27,6 +27,8 @@ if TYPE_CHECKING:
|
||||||
|
|
||||||
|
|
||||||
def index(iterator, item, start=0, end=None):
|
def index(iterator, item, start=0, end=None):
|
||||||
|
from itertools import islice
|
||||||
|
|
||||||
for i, elem in islice(enumerate(iterator), start, end):
|
for i, elem in islice(enumerate(iterator), start, end):
|
||||||
if item == elem:
|
if item == elem:
|
||||||
return i
|
return i
|
||||||
|
|
@ -34,29 +36,6 @@ def index(iterator, item, start=0, end=None):
|
||||||
raise ValueError(f"{item} is not in {type(iterator)}")
|
raise ValueError(f"{item} is not in {type(iterator)}")
|
||||||
|
|
||||||
|
|
||||||
def islice(iterator, start=0, end=None, step=1):
|
|
||||||
if start < 0 or (end is not None and end < 0) or step < 0:
|
|
||||||
raise ValueError("Indices must be non-negative")
|
|
||||||
if step == 0:
|
|
||||||
raise ValueError("Step cannot be 0")
|
|
||||||
|
|
||||||
it = iter(iterator)
|
|
||||||
|
|
||||||
for _ in range(start):
|
|
||||||
next(it)
|
|
||||||
|
|
||||||
if end is None:
|
|
||||||
for i, element in enumerate(it):
|
|
||||||
if i % step == 0:
|
|
||||||
yield element
|
|
||||||
else:
|
|
||||||
for i, element in enumerate(it):
|
|
||||||
if i % step == 0 and i + start < end - start:
|
|
||||||
yield element
|
|
||||||
elif i + start >= end - start:
|
|
||||||
break
|
|
||||||
|
|
||||||
|
|
||||||
def repeat(item, count):
|
def repeat(item, count):
|
||||||
for i in range(count):
|
for i in range(count):
|
||||||
yield item
|
yield item
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ from ..decorators import substitute_in_graph
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"chain",
|
"chain",
|
||||||
"chain_from_iterable",
|
"chain_from_iterable",
|
||||||
|
"islice",
|
||||||
"tee",
|
"tee",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -35,6 +36,35 @@ def chain_from_iterable(iterable: Iterable[Iterable[_T]], /) -> Iterator[_T]:
|
||||||
chain.from_iterable = chain_from_iterable # type: ignore[method-assign]
|
chain.from_iterable = chain_from_iterable # type: ignore[method-assign]
|
||||||
|
|
||||||
|
|
||||||
|
# Reference: https://docs.python.org/3/library/itertools.html#itertools.islice
|
||||||
|
@substitute_in_graph(itertools.islice, is_embedded_type=True) # type: ignore[arg-type]
|
||||||
|
def islice(iterable: Iterable[_T], /, *args: int | None) -> Iterator[_T]:
|
||||||
|
s = slice(*args)
|
||||||
|
start = 0 if s.start is None else s.start
|
||||||
|
stop = s.stop
|
||||||
|
step = 1 if s.step is None else s.step
|
||||||
|
if start < 0 or (stop is not None and stop < 0) or step <= 0:
|
||||||
|
raise ValueError(
|
||||||
|
"Indices for islice() must be None or an integer: 0 <= x <= sys.maxsize.",
|
||||||
|
)
|
||||||
|
|
||||||
|
if stop is None:
|
||||||
|
# TODO: use indices = itertools.count() and merge implementation with the else branch
|
||||||
|
# when we support infinite iterators
|
||||||
|
next_i = start
|
||||||
|
for i, element in enumerate(iterable):
|
||||||
|
if i == next_i:
|
||||||
|
yield element
|
||||||
|
next_i += step
|
||||||
|
else:
|
||||||
|
indices = range(max(start, stop))
|
||||||
|
next_i = start
|
||||||
|
for i, element in zip(indices, iterable):
|
||||||
|
if i == next_i:
|
||||||
|
yield element
|
||||||
|
next_i += step
|
||||||
|
|
||||||
|
|
||||||
# Reference: https://docs.python.org/3/library/itertools.html#itertools.tee
|
# Reference: https://docs.python.org/3/library/itertools.html#itertools.tee
|
||||||
@substitute_in_graph(itertools.tee)
|
@substitute_in_graph(itertools.tee)
|
||||||
def tee(iterable: Iterable[_T], n: int = 2, /) -> tuple[Iterator[_T], ...]:
|
def tee(iterable: Iterable[_T], n: int = 2, /) -> tuple[Iterator[_T], ...]:
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,6 @@ import enum
|
||||||
import functools
|
import functools
|
||||||
import importlib
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
import itertools
|
|
||||||
import linecache
|
import linecache
|
||||||
import logging
|
import logging
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
|
@ -2993,7 +2992,6 @@ def _builtin_function_ids() -> Dict[int, str]:
|
||||||
if not k.startswith("_") and callable(v)
|
if not k.startswith("_") and callable(v)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
rv.update({id(v): f"itertools.{v.__name__}" for v in (itertools.islice,)})
|
|
||||||
rv.update(
|
rv.update(
|
||||||
{
|
{
|
||||||
id(cast): "typing.cast",
|
id(cast): "typing.cast",
|
||||||
|
|
|
||||||
|
|
@ -1917,15 +1917,6 @@ class BuiltinVariable(VariableTracker):
|
||||||
)
|
)
|
||||||
return variables.ListVariable(items)
|
return variables.ListVariable(items)
|
||||||
|
|
||||||
def call_islice(self, tx: "InstructionTranslator", iterable, *args):
|
|
||||||
if iterable.has_unpack_var_sequence(tx) and all(
|
|
||||||
x.is_python_constant() for x in args
|
|
||||||
):
|
|
||||||
const_args = [x.as_python_constant() for x in args]
|
|
||||||
items = iterable.unpack_var_sequence(tx)
|
|
||||||
items = list(itertools.islice(items, *const_args))
|
|
||||||
return variables.TupleVariable(items)
|
|
||||||
|
|
||||||
# neg is a constant fold function, so we only get here if constant fold is not valid
|
# neg is a constant fold function, so we only get here if constant fold is not valid
|
||||||
def call_neg(self, tx: "InstructionTranslator", a):
|
def call_neg(self, tx: "InstructionTranslator", a):
|
||||||
if isinstance(a, SymNodeVariable):
|
if isinstance(a, SymNodeVariable):
|
||||||
|
|
|
||||||
|
|
@ -166,12 +166,6 @@ class ItertoolsVariable(VariableTracker):
|
||||||
from_exc=e,
|
from_exc=e,
|
||||||
)
|
)
|
||||||
return variables.ListIteratorVariable(result, mutable_local=MutableLocal())
|
return variables.ListIteratorVariable(result, mutable_local=MutableLocal())
|
||||||
elif self.value is itertools.islice:
|
|
||||||
from .builder import SourcelessBuilder
|
|
||||||
|
|
||||||
return tx.inline_user_function_return(
|
|
||||||
SourcelessBuilder.create(tx, polyfills.islice), args, kwargs
|
|
||||||
)
|
|
||||||
elif self.value is itertools.repeat:
|
elif self.value is itertools.repeat:
|
||||||
if len(args) < 2:
|
if len(args) < 2:
|
||||||
return variables.RepeatIteratorVariable(
|
return variables.RepeatIteratorVariable(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user