Revert "[dynamo][itertools] support itertools.tee (#133771)"

This reverts commit 1dbd3476de.

Reverted https://github.com/pytorch/pytorch/pull/133771 on behalf of https://github.com/ZainRizvi due to Sorry, have to revert this in order to be able to revert https://github.com/pytorch/pytorch/pull/133769 ([comment](https://github.com/pytorch/pytorch/pull/133771#issuecomment-2316611158))
This commit is contained in:
PyTorch MergeBot 2024-08-29 02:49:30 +00:00
parent eaec9e80b8
commit f65df5edae
3 changed files with 1 additions and 54 deletions

View File

@ -10061,22 +10061,6 @@ def ___make_guard_fn():
self.assertEqual(eager, compiled)
self.assertEqual(len(counters["graph_break"]), 0)
def test_itertools_tee(self):
counters.clear()
def fn(l):
a, b = itertools.tee(l)
return list(a), list(b)
l = [1, 2, 2, 3, 4, 4, 4, 1, 2]
eager = fn(l)
compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
compiled = compiled_fn(l)
self.assertEqual(eager, compiled)
self.assertEqual(len(counters["graph_break"]), 0)
def test_list_iterator_contains(self):
def fn(x):
it = iter(["my_weight", "not_my_weight"])

View File

@ -1,34 +0,0 @@
"""
Python polyfills for itertools
"""
import itertools
from typing import Iterable, Iterator, Tuple, TypeVar
from ..decorators import substitute_in_graph
__all__ = ["tee"]
_T = TypeVar("_T")
# Reference: https://docs.python.org/3/library/itertools.html#itertools.tee
@substitute_in_graph(itertools.tee)
def tee(iterable: Iterable[_T], n: int = 2, /) -> Tuple[Iterator[_T], ...]:
iterator = iter(iterable)
shared_link = [None, None]
def _tee(link) -> Iterator[_T]: # type: ignore[no-untyped-def]
try:
while True:
if link[1] is None:
link[0] = next(iterator)
link[1] = [None, None]
value, link = link
yield value
except StopIteration:
return
return tuple(_tee(shared_link) for _ in range(n))

View File

@ -11,10 +11,7 @@ if TYPE_CHECKING:
from types import ModuleType
POLYFILLED_MODULE_NAMES: Tuple[str, ...] = (
"builtins",
"itertools",
)
POLYFILLED_MODULE_NAMES: Tuple[str, ...] = ("builtins",)
POLYFILLED_MODULES: Tuple["ModuleType", ...] = tuple(
importlib.import_module(f".{submodule}", package=polyfills.__name__)
for submodule in POLYFILLED_MODULE_NAMES