mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo][itertools] support itertools.tee (#133771)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133771 Approved by: https://github.com/jansel
This commit is contained in:
parent
43bbd781f2
commit
1dbd3476de
|
|
@ -10045,6 +10045,22 @@ def ___make_guard_fn():
|
|||
self.assertEqual(eager, compiled)
|
||||
self.assertEqual(len(counters["graph_break"]), 0)
|
||||
|
||||
def test_itertools_tee(self):
|
||||
counters.clear()
|
||||
|
||||
def fn(l):
|
||||
a, b = itertools.tee(l)
|
||||
return list(a), list(b)
|
||||
|
||||
l = [1, 2, 2, 3, 4, 4, 4, 1, 2]
|
||||
eager = fn(l)
|
||||
|
||||
compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
|
||||
compiled = compiled_fn(l)
|
||||
|
||||
self.assertEqual(eager, compiled)
|
||||
self.assertEqual(len(counters["graph_break"]), 0)
|
||||
|
||||
def test_list_iterator_contains(self):
|
||||
def fn(x):
|
||||
it = iter(["my_weight", "not_my_weight"])
|
||||
|
|
|
|||
34
torch/_dynamo/polyfills/itertools.py
Normal file
34
torch/_dynamo/polyfills/itertools.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
"""
|
||||
Python polyfills for itertools
|
||||
"""
|
||||
|
||||
import itertools
|
||||
from typing import Iterable, Iterator, Tuple, TypeVar
|
||||
|
||||
from ..decorators import substitute_in_graph
|
||||
|
||||
|
||||
__all__ = ["tee"]
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
# Reference: https://docs.python.org/3/library/itertools.html#itertools.tee
|
||||
@substitute_in_graph(itertools.tee)
|
||||
def tee(iterable: Iterable[_T], n: int = 2, /) -> Tuple[Iterator[_T], ...]:
|
||||
iterator = iter(iterable)
|
||||
shared_link = [None, None]
|
||||
|
||||
def _tee(link) -> Iterator[_T]: # type: ignore[no-untyped-def]
|
||||
try:
|
||||
while True:
|
||||
if link[1] is None:
|
||||
link[0] = next(iterator)
|
||||
link[1] = [None, None]
|
||||
value, link = link
|
||||
yield value
|
||||
except StopIteration:
|
||||
return
|
||||
|
||||
return tuple(_tee(shared_link) for _ in range(n))
|
||||
|
|
@ -11,7 +11,10 @@ if TYPE_CHECKING:
|
|||
from types import ModuleType
|
||||
|
||||
|
||||
POLYFILLED_MODULE_NAMES: Tuple[str, ...] = ("builtins",)
|
||||
POLYFILLED_MODULE_NAMES: Tuple[str, ...] = (
|
||||
"builtins",
|
||||
"itertools",
|
||||
)
|
||||
POLYFILLED_MODULES: Tuple["ModuleType", ...] = tuple(
|
||||
importlib.import_module(f".{submodule}", package=polyfills.__name__)
|
||||
for submodule in POLYFILLED_MODULE_NAMES
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user