mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo][itertools] refactor itertools.chain and itertools.chain.from_iterable to use polyfills (#133864)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133864 Approved by: https://github.com/jansel
This commit is contained in:
parent
5dad6a5a84
commit
ebbdeeede1
|
|
@ -170,6 +170,8 @@ def substitute_in_graph(
|
||||||
*,
|
*,
|
||||||
can_constant_fold_through: bool = False,
|
can_constant_fold_through: bool = False,
|
||||||
skip_signature_check: bool = False,
|
skip_signature_check: bool = False,
|
||||||
|
# type that is embedded in the Python interpreter
|
||||||
|
is_embedded_type: bool = False, # internal use only
|
||||||
) -> Callable[[_F], _F]:
|
) -> Callable[[_F], _F]:
|
||||||
"""
|
"""
|
||||||
Register a polyfill handler for a function, usually a C function from the C extension, to be
|
Register a polyfill handler for a function, usually a C function from the C extension, to be
|
||||||
|
|
@ -219,10 +221,22 @@ def substitute_in_graph(
|
||||||
>>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3)
|
>>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3)
|
||||||
2
|
2
|
||||||
"""
|
"""
|
||||||
if not is_function(original_fn):
|
if not is_function(original_fn) and not (
|
||||||
|
is_embedded_type and inspect.isclass(original_fn)
|
||||||
|
):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"substitute_in_graph expects a function but got {type(original_fn)!r}"
|
f"substitute_in_graph expects a function but got {type(original_fn)!r}"
|
||||||
)
|
)
|
||||||
|
if is_embedded_type:
|
||||||
|
if not inspect.isclass(original_fn):
|
||||||
|
raise TypeError(
|
||||||
|
f"substitute_in_graph expects a class but got {type(original_fn)!r}"
|
||||||
|
)
|
||||||
|
|
||||||
|
from .variables.builder import ITERTOOLS_POLYFILLED_TYPE_IDS, ITERTOOLS_TYPE_IDS
|
||||||
|
|
||||||
|
if id(original_fn) in ITERTOOLS_TYPE_IDS:
|
||||||
|
ITERTOOLS_POLYFILLED_TYPE_IDS.add(id(original_fn))
|
||||||
|
|
||||||
def wrapper(traceable_fn: _F) -> _F:
|
def wrapper(traceable_fn: _F) -> _F:
|
||||||
if not is_function(traceable_fn):
|
if not is_function(traceable_fn):
|
||||||
|
|
|
||||||
|
|
@ -4,15 +4,28 @@ Python polyfills for common builtins.
|
||||||
|
|
||||||
# NOTE: 1. Please do not import any submodule in the directory here to avoid circular imports.
|
# NOTE: 1. Please do not import any submodule in the directory here to avoid circular imports.
|
||||||
# 2. While adding a new polyfill module, also add it to POLYFILLED_MODULE_NAMES in loader.py.
|
# 2. While adding a new polyfill module, also add it to POLYFILLED_MODULE_NAMES in loader.py.
|
||||||
|
# Add it in the TYPE_CHECKING block below as well.
|
||||||
|
|
||||||
# mypy: allow-untyped-defs
|
# mypy: allow-untyped-defs
|
||||||
|
|
||||||
import math
|
from typing import Any, Callable, Sequence, TYPE_CHECKING
|
||||||
from typing import Any, Callable, Sequence
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
# Load by torch._dynamo.polyfills.loader
|
||||||
|
# See also the POLYFILLED_MODULE_NAMES in torch/_dynamo/polyfills/loader.py
|
||||||
|
# Put the submodules here to avoid circular imports
|
||||||
|
from . import (
|
||||||
|
builtins as builtins,
|
||||||
|
functools as functools,
|
||||||
|
itertools as itertools,
|
||||||
|
os as os,
|
||||||
|
sys as sys,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def index(iterator, item, start=0, end=None):
|
def index(iterator, item, start=0, end=None):
|
||||||
for i, elem in islice(enumerate(iterator), start, end):
|
for i, elem in islice(enumerate(iterator), start, end):
|
||||||
if item == elem:
|
if item == elem:
|
||||||
|
|
@ -50,6 +63,8 @@ def repeat(item, count):
|
||||||
|
|
||||||
|
|
||||||
def radians(x):
|
def radians(x):
|
||||||
|
import math
|
||||||
|
|
||||||
return math.pi / 180.0 * x
|
return math.pi / 180.0 * x
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,12 +10,31 @@ from typing import Iterable, Iterator, TypeVar
|
||||||
from ..decorators import substitute_in_graph
|
from ..decorators import substitute_in_graph
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["tee"]
|
__all__ = [
|
||||||
|
"chain",
|
||||||
|
"chain_from_iterable",
|
||||||
|
"tee",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
_T = TypeVar("_T")
|
_T = TypeVar("_T")
|
||||||
|
|
||||||
|
|
||||||
|
# Reference: https://docs.python.org/3/library/itertools.html#itertools.chain
|
||||||
|
@substitute_in_graph(itertools.chain, is_embedded_type=True) # type: ignore[arg-type]
|
||||||
|
def chain(*iterables: Iterable[_T]) -> Iterator[_T]:
|
||||||
|
for iterable in iterables:
|
||||||
|
yield from iterable
|
||||||
|
|
||||||
|
|
||||||
|
@substitute_in_graph(itertools.chain.from_iterable) # type: ignore[arg-type]
|
||||||
|
def chain_from_iterable(iterable: Iterable[Iterable[_T]], /) -> Iterator[_T]:
|
||||||
|
return itertools.chain(*iterable)
|
||||||
|
|
||||||
|
|
||||||
|
chain.from_iterable = chain_from_iterable # type: ignore[method-assign]
|
||||||
|
|
||||||
|
|
||||||
# Reference: https://docs.python.org/3/library/itertools.html#itertools.tee
|
# Reference: https://docs.python.org/3/library/itertools.html#itertools.tee
|
||||||
@substitute_in_graph(itertools.tee)
|
@substitute_in_graph(itertools.tee)
|
||||||
def tee(iterable: Iterable[_T], n: int = 2, /) -> tuple[Iterator[_T], ...]:
|
def tee(iterable: Iterable[_T], n: int = 2, /) -> tuple[Iterator[_T], ...]:
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ if TYPE_CHECKING:
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
|
|
||||||
|
|
||||||
|
# See also the TYPE_CHECKING block in torch/_dynamo/polyfills/__init__.py
|
||||||
POLYFILLED_MODULE_NAMES: Tuple[str, ...] = (
|
POLYFILLED_MODULE_NAMES: Tuple[str, ...] = (
|
||||||
"builtins",
|
"builtins",
|
||||||
"functools",
|
"functools",
|
||||||
|
|
|
||||||
|
|
@ -2993,9 +2993,7 @@ def _builtin_function_ids() -> Dict[int, str]:
|
||||||
if not k.startswith("_") and callable(v)
|
if not k.startswith("_") and callable(v)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
rv.update(
|
rv.update({id(v): f"itertools.{v.__name__}" for v in (itertools.islice,)})
|
||||||
{id(v): f"itertools.{v.__name__}" for v in (itertools.chain, itertools.islice)}
|
|
||||||
)
|
|
||||||
rv.update(
|
rv.update(
|
||||||
{
|
{
|
||||||
id(cast): "typing.cast",
|
id(cast): "typing.cast",
|
||||||
|
|
@ -3474,9 +3472,7 @@ def check_verbose(obj, is_inlined_call=False):
|
||||||
|
|
||||||
# Consulte the central trace rules defined in torch._dynamo.trace_rules.
|
# Consulte the central trace rules defined in torch._dynamo.trace_rules.
|
||||||
reasons: Set[str] = set()
|
reasons: Set[str] = set()
|
||||||
rule = torch._dynamo.trace_rules.lookup_inner(
|
rule = lookup_inner(fi.py_obj, fi.name, fi.filename, is_inlined_call, reasons)
|
||||||
fi.py_obj, fi.name, fi.filename, is_inlined_call, reasons
|
|
||||||
)
|
|
||||||
if issubclass(rule, (UserFunctionVariable, PolyfilledFunctionVariable)):
|
if issubclass(rule, (UserFunctionVariable, PolyfilledFunctionVariable)):
|
||||||
return SkipResult(
|
return SkipResult(
|
||||||
False,
|
False,
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,17 @@ import sys
|
||||||
import types
|
import types
|
||||||
import warnings
|
import warnings
|
||||||
import weakref
|
import weakref
|
||||||
from typing import Any, List, MutableMapping, NamedTuple, Optional, TYPE_CHECKING, Union
|
from typing import (
|
||||||
|
Any,
|
||||||
|
FrozenSet,
|
||||||
|
List,
|
||||||
|
MutableMapping,
|
||||||
|
NamedTuple,
|
||||||
|
Optional,
|
||||||
|
Set,
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import SymInt
|
from torch import SymInt
|
||||||
|
|
@ -320,6 +330,17 @@ class FrameStateSizeEntry:
|
||||||
stride: Optional[List[int]]
|
stride: Optional[List[int]]
|
||||||
|
|
||||||
|
|
||||||
|
# All class-based iterators in itertools
|
||||||
|
# NOTE: use id() because some objects are not hashable, it will raise error during lookup
|
||||||
|
ITERTOOLS_TYPE_IDS: FrozenSet[int] = frozenset(
|
||||||
|
id(member)
|
||||||
|
for name, member in vars(itertools).items()
|
||||||
|
if not name.startswith("_") and inspect.isclass(member)
|
||||||
|
)
|
||||||
|
# Will be updated later in substitute_in_graph in torch/_dynamo/polyfills/itertools.py
|
||||||
|
ITERTOOLS_POLYFILLED_TYPE_IDS: Set[int] = set()
|
||||||
|
|
||||||
|
|
||||||
class VariableBuilder:
|
class VariableBuilder:
|
||||||
"""Wrap a python value in a VariableTracker() instance"""
|
"""Wrap a python value in a VariableTracker() instance"""
|
||||||
|
|
||||||
|
|
@ -875,7 +896,10 @@ class VariableBuilder:
|
||||||
value,
|
value,
|
||||||
source=self.source,
|
source=self.source,
|
||||||
)
|
)
|
||||||
elif istype(value, type) and value in itertools.__dict__.values():
|
elif (
|
||||||
|
id(value) in ITERTOOLS_TYPE_IDS
|
||||||
|
and id(value) not in ITERTOOLS_POLYFILLED_TYPE_IDS
|
||||||
|
):
|
||||||
self.install_guards(GuardBuilder.FUNCTION_MATCH)
|
self.install_guards(GuardBuilder.FUNCTION_MATCH)
|
||||||
return ItertoolsVariable(value, source=self.source)
|
return ItertoolsVariable(value, source=self.source)
|
||||||
elif isinstance(value, torch.SymBool):
|
elif isinstance(value, torch.SymBool):
|
||||||
|
|
|
||||||
|
|
@ -994,15 +994,6 @@ class BuiltinVariable(VariableTracker):
|
||||||
)
|
)
|
||||||
if self.fn is dict and name == "fromkeys":
|
if self.fn is dict and name == "fromkeys":
|
||||||
return BuiltinVariable.call_custom_dict_fromkeys(tx, dict, *args, **kwargs)
|
return BuiltinVariable.call_custom_dict_fromkeys(tx, dict, *args, **kwargs)
|
||||||
if self.fn is itertools.chain and name == "from_iterable":
|
|
||||||
assert len(args) == 1
|
|
||||||
assert len(kwargs) == 0
|
|
||||||
obj = args[0]
|
|
||||||
items = []
|
|
||||||
for item in obj.unpack_var_sequence(tx):
|
|
||||||
items.extend(item.unpack_var_sequence(tx))
|
|
||||||
return variables.TupleVariable(items)
|
|
||||||
|
|
||||||
return super().call_method(tx, name, args, kwargs)
|
return super().call_method(tx, name, args, kwargs)
|
||||||
|
|
||||||
def _call_int_float(self, tx: "InstructionTranslator", arg):
|
def _call_int_float(self, tx: "InstructionTranslator", arg):
|
||||||
|
|
@ -1942,13 +1933,6 @@ class BuiltinVariable(VariableTracker):
|
||||||
)
|
)
|
||||||
return variables.ListVariable(items)
|
return variables.ListVariable(items)
|
||||||
|
|
||||||
def call_chain(self, tx: "InstructionTranslator", *args):
|
|
||||||
if all(obj.has_unpack_var_sequence(tx) for obj in args):
|
|
||||||
items = []
|
|
||||||
for obj in args:
|
|
||||||
items.extend(obj.unpack_var_sequence(tx))
|
|
||||||
return variables.TupleVariable(items)
|
|
||||||
|
|
||||||
def call_islice(self, tx: "InstructionTranslator", iterable, *args):
|
def call_islice(self, tx: "InstructionTranslator", iterable, *args):
|
||||||
if iterable.has_unpack_var_sequence(tx) and all(
|
if iterable.has_unpack_var_sequence(tx) and all(
|
||||||
x.is_python_constant() for x in args
|
x.is_python_constant() for x in args
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ from ..utils import (
|
||||||
check_constant_args,
|
check_constant_args,
|
||||||
check_unspec_or_constant_args,
|
check_unspec_or_constant_args,
|
||||||
identity,
|
identity,
|
||||||
|
is_function,
|
||||||
is_wrapper_or_member_descriptor,
|
is_wrapper_or_member_descriptor,
|
||||||
istype,
|
istype,
|
||||||
make_cell,
|
make_cell,
|
||||||
|
|
@ -992,6 +993,27 @@ class PolyfilledFunctionVariable(VariableTracker):
|
||||||
handler,
|
handler,
|
||||||
).call_function(tx, args, kwargs)
|
).call_function(tx, args, kwargs)
|
||||||
|
|
||||||
|
return super().call_function(tx, args, kwargs)
|
||||||
|
|
||||||
|
def call_method(
|
||||||
|
self,
|
||||||
|
tx,
|
||||||
|
name,
|
||||||
|
args: "List[VariableTracker]",
|
||||||
|
kwargs: "Dict[str, VariableTracker]",
|
||||||
|
) -> "VariableTracker":
|
||||||
|
if name == "__call__":
|
||||||
|
return self.call_function(tx, args, kwargs)
|
||||||
|
|
||||||
|
method = getattr(self.fn, name, None)
|
||||||
|
assert method is not None, f"Member {name} not found in {self.fn}"
|
||||||
|
assert is_function(method), f"Member {name} is not callable in {self.fn}"
|
||||||
|
options = {}
|
||||||
|
if self.source:
|
||||||
|
options["source"] = AttrSource(self.source, name)
|
||||||
|
member_variable = PolyfilledFunctionVariable(method, **options)
|
||||||
|
return member_variable.call_function(tx, args, kwargs)
|
||||||
|
|
||||||
def as_python_constant(self):
|
def as_python_constant(self):
|
||||||
return self.fn
|
return self.fn
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -52,14 +52,6 @@ class ItertoolsVariable(VariableTracker):
|
||||||
for item in itertools.product(*seqs):
|
for item in itertools.product(*seqs):
|
||||||
items.append(variables.TupleVariable(list(item)))
|
items.append(variables.TupleVariable(list(item)))
|
||||||
return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
|
return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
|
||||||
elif (
|
|
||||||
self.value is itertools.chain
|
|
||||||
and not kwargs
|
|
||||||
and all(arg.has_unpack_var_sequence(tx) for arg in args)
|
|
||||||
):
|
|
||||||
seqs = [arg.unpack_var_sequence(tx) for arg in args]
|
|
||||||
items = list(itertools.chain.from_iterable(seqs))
|
|
||||||
return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
|
|
||||||
elif self.value is itertools.accumulate:
|
elif self.value is itertools.accumulate:
|
||||||
from .builtin import BuiltinVariable
|
from .builtin import BuiltinVariable
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user