[dynamo][itertools] support itertools.tee (#133771)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133771
Approved by: https://github.com/jansel
This commit is contained in:
Xuehai Pan 2024-08-27 03:56:10 +08:00 committed by PyTorch MergeBot
parent 43bbd781f2
commit 1dbd3476de
3 changed files with 54 additions and 1 deletions

View File

@ -10045,6 +10045,22 @@ def ___make_guard_fn():
self.assertEqual(eager, compiled)
self.assertEqual(len(counters["graph_break"]), 0)
def test_itertools_tee(self):
counters.clear()
def fn(l):
a, b = itertools.tee(l)
return list(a), list(b)
l = [1, 2, 2, 3, 4, 4, 4, 1, 2]
eager = fn(l)
compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
compiled = compiled_fn(l)
self.assertEqual(eager, compiled)
self.assertEqual(len(counters["graph_break"]), 0)
def test_list_iterator_contains(self):
def fn(x):
it = iter(["my_weight", "not_my_weight"])

View File

@ -0,0 +1,34 @@
"""
Python polyfills for itertools
"""
import itertools
from typing import Iterable, Iterator, Tuple, TypeVar
from ..decorators import substitute_in_graph
__all__ = ["tee"]
_T = TypeVar("_T")
# Reference: https://docs.python.org/3/library/itertools.html#itertools.tee
@substitute_in_graph(itertools.tee)
def tee(iterable: Iterable[_T], n: int = 2, /) -> Tuple[Iterator[_T], ...]:
iterator = iter(iterable)
shared_link = [None, None]
def _tee(link) -> Iterator[_T]: # type: ignore[no-untyped-def]
try:
while True:
if link[1] is None:
link[0] = next(iterator)
link[1] = [None, None]
value, link = link
yield value
except StopIteration:
return
return tuple(_tee(shared_link) for _ in range(n))

View File

@ -11,7 +11,10 @@ if TYPE_CHECKING:
from types import ModuleType
POLYFILLED_MODULE_NAMES: Tuple[str, ...] = ("builtins",)
POLYFILLED_MODULE_NAMES: Tuple[str, ...] = (
"builtins",
"itertools",
)
POLYFILLED_MODULES: Tuple["ModuleType", ...] = tuple(
importlib.import_module(f".{submodule}", package=polyfills.__name__)
for submodule in POLYFILLED_MODULE_NAMES