mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo] simplify polyfill registration for builtins.all and builtins.any (#133769)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133769 Approved by: https://github.com/jansel
This commit is contained in:
parent
b977abd5de
commit
e09324e7da
|
|
@ -168,6 +168,7 @@ def forbid_in_graph(fn):
|
||||||
def substitute_in_graph(
|
def substitute_in_graph(
|
||||||
original_fn: _F,
|
original_fn: _F,
|
||||||
*,
|
*,
|
||||||
|
can_constant_fold_through: bool = False,
|
||||||
skip_signature_check: bool = False,
|
skip_signature_check: bool = False,
|
||||||
) -> Callable[[_F], _F]:
|
) -> Callable[[_F], _F]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -187,6 +188,10 @@ def substitute_in_graph(
|
||||||
Args:
|
Args:
|
||||||
original_fn (callable): The original function, usually a C function, to register a polyfill
|
original_fn (callable): The original function, usually a C function, to register a polyfill
|
||||||
handler for.
|
handler for.
|
||||||
|
can_constant_fold_through (bool, optional): Whether the polyfill handler can be constant
|
||||||
|
folded through. That is, if the polyfill handler is a pure function and its arguments
|
||||||
|
are constant, the result of the polyfill handler can be constant folded during the
|
||||||
|
compilation. Defaults to ``False``.
|
||||||
skip_signature_check (bool, optional): Whether to skip the signature check between the
|
skip_signature_check (bool, optional): Whether to skip the signature check between the
|
||||||
original function and the polyfill handler. Defaults to ``False``.
|
original function and the polyfill handler. Defaults to ``False``.
|
||||||
|
|
||||||
|
|
@ -319,6 +324,7 @@ def substitute_in_graph(
|
||||||
|
|
||||||
wrapped.__torch_dynamo_original__ = original_fn # type: ignore[attr-defined]
|
wrapped.__torch_dynamo_original__ = original_fn # type: ignore[attr-defined]
|
||||||
wrapped.__torch_dynamo_polyfill__ = traceable_fn # type: ignore[attr-defined]
|
wrapped.__torch_dynamo_polyfill__ = traceable_fn # type: ignore[attr-defined]
|
||||||
|
wrapped.__torch_dynamo_can_constant_fold_through__ = can_constant_fold_through # type: ignore[attr-defined]
|
||||||
|
|
||||||
return wrapped # type: ignore[return-value]
|
return wrapped # type: ignore[return-value]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -13,20 +13,6 @@ from typing import Any, Callable, Sequence
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def all(iterator):
|
|
||||||
for elem in iterator:
|
|
||||||
if not elem:
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def any(iterator):
|
|
||||||
for elem in iterator:
|
|
||||||
if elem:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def index(iterator, item, start=0, end=None):
|
def index(iterator, item, start=0, end=None):
|
||||||
for i, elem in islice(enumerate(iterator), start, end):
|
for i, elem in islice(enumerate(iterator), start, end):
|
||||||
if item == elem:
|
if item == elem:
|
||||||
|
|
|
||||||
|
|
@ -2,5 +2,29 @@
|
||||||
Python polyfills for builtins
|
Python polyfills for builtins
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import builtins
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
__all__ = [] # type: ignore[var-annotated]
|
from ..decorators import substitute_in_graph
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"all",
|
||||||
|
"any",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@substitute_in_graph(builtins.all, can_constant_fold_through=True)
|
||||||
|
def all(iterable: Iterable[object], /) -> bool:
|
||||||
|
for elem in iterable:
|
||||||
|
if not elem:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@substitute_in_graph(builtins.any, can_constant_fold_through=True)
|
||||||
|
def any(iterable: Iterable[object], /) -> bool:
|
||||||
|
for elem in iterable:
|
||||||
|
if elem:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ __all__ = ["fspath"]
|
||||||
|
|
||||||
|
|
||||||
# Copied from os.py in the standard library
|
# Copied from os.py in the standard library
|
||||||
@substitute_in_graph(os.fspath)
|
@substitute_in_graph(os.fspath, can_constant_fold_through=True)
|
||||||
def fspath(path: AnyStr | os.PathLike[AnyStr]) -> AnyStr:
|
def fspath(path: AnyStr | os.PathLike[AnyStr]) -> AnyStr:
|
||||||
if isinstance(path, (str, bytes)):
|
if isinstance(path, (str, bytes)):
|
||||||
return path
|
return path
|
||||||
|
|
|
||||||
|
|
@ -2980,6 +2980,7 @@ def _disallowed_callable_ids() -> Dict[int, str]:
|
||||||
|
|
||||||
@FunctionIdSet
|
@FunctionIdSet
|
||||||
def _builtin_function_ids() -> Dict[int, str]:
|
def _builtin_function_ids() -> Dict[int, str]:
|
||||||
|
# See also torch/_dynamo/polyfills/loader.py, which removes items in _builtin_function_ids
|
||||||
rv = {
|
rv = {
|
||||||
id(v): f"builtins.{k}"
|
id(v): f"builtins.{k}"
|
||||||
for k, v in builtins.__dict__.items()
|
for k, v in builtins.__dict__.items()
|
||||||
|
|
@ -3072,6 +3073,7 @@ def is_forbidden(obj) -> bool:
|
||||||
|
|
||||||
|
|
||||||
def is_builtin_callable(obj) -> bool:
|
def is_builtin_callable(obj) -> bool:
|
||||||
|
# See also torch/_dynamo/polyfills/loader.py, which removes items in _builtin_function_ids
|
||||||
return id(obj) in _builtin_function_ids
|
return id(obj) in _builtin_function_ids
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ import torch
|
||||||
from torch import sym_float, sym_int
|
from torch import sym_float, sym_int
|
||||||
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
|
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
|
||||||
|
|
||||||
from .. import config, polyfills, variables
|
from .. import config, variables
|
||||||
from ..exc import (
|
from ..exc import (
|
||||||
AttributeMutationError,
|
AttributeMutationError,
|
||||||
unimplemented,
|
unimplemented,
|
||||||
|
|
@ -94,19 +94,6 @@ IN_PLACE_DESUGARING_MAP = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _polyfill_call_impl(name):
|
|
||||||
"""Create a BuiltinVariable.call_{name} method that inlines through polyfill.{name}"""
|
|
||||||
|
|
||||||
def call_fn(self, tx: "InstructionTranslator", *args, **kwargs):
|
|
||||||
return tx.inline_user_function_return(
|
|
||||||
variables.UserFunctionVariable(fn), args, kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
fn = getattr(polyfills, name)
|
|
||||||
call_fn.__name__ = f"call_{name}"
|
|
||||||
return call_fn
|
|
||||||
|
|
||||||
|
|
||||||
class BuiltinVariable(VariableTracker):
|
class BuiltinVariable(VariableTracker):
|
||||||
_SENTINEL = object()
|
_SENTINEL = object()
|
||||||
_nonvar_fields = {
|
_nonvar_fields = {
|
||||||
|
|
@ -2124,9 +2111,6 @@ class BuiltinVariable(VariableTracker):
|
||||||
):
|
):
|
||||||
return a.call_method(tx, "__contains__", [b], {})
|
return a.call_method(tx, "__contains__", [b], {})
|
||||||
|
|
||||||
call_all = _polyfill_call_impl("all")
|
|
||||||
call_any = _polyfill_call_impl("any")
|
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def dynamo_disable_grad(tx):
|
def dynamo_disable_grad(tx):
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ from ..guards import GuardBuilder, install_guard
|
||||||
from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource
|
from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource
|
||||||
from ..utils import (
|
from ..utils import (
|
||||||
check_constant_args,
|
check_constant_args,
|
||||||
|
check_unspec_or_constant_args,
|
||||||
identity,
|
identity,
|
||||||
is_wrapper_or_member_descriptor,
|
is_wrapper_or_member_descriptor,
|
||||||
istype,
|
istype,
|
||||||
|
|
@ -965,6 +966,15 @@ class PolyfilledFunctionVariable(VariableTracker):
|
||||||
handler = self._get_polyfill_handlers().get(self.fn)
|
handler = self._get_polyfill_handlers().get(self.fn)
|
||||||
if handler:
|
if handler:
|
||||||
assert callable(handler)
|
assert callable(handler)
|
||||||
|
if getattr(
|
||||||
|
handler, "__torch_dynamo_can_constant_fold_through__", False
|
||||||
|
) and check_unspec_or_constant_args(args, kwargs):
|
||||||
|
return ConstantVariable.create(
|
||||||
|
self.fn( # use the original function which is faster than the polyfill
|
||||||
|
*[x.as_python_constant() for x in args],
|
||||||
|
**{k: v.as_python_constant() for k, v in kwargs.items()},
|
||||||
|
)
|
||||||
|
)
|
||||||
return SourcelessBuilder.create(tx, handler).call_function(tx, args, kwargs)
|
return SourcelessBuilder.create(tx, handler).call_function(tx, args, kwargs)
|
||||||
|
|
||||||
for candidate in ("__torch_dynamo_polyfill__", "__python_implementation__"):
|
for candidate in ("__torch_dynamo_polyfill__", "__python_implementation__"):
|
||||||
|
|
|
||||||
|
|
@ -123,6 +123,7 @@ def allow_in_graph(fn):
|
||||||
def substitute_in_graph(
|
def substitute_in_graph(
|
||||||
original_fn: _F,
|
original_fn: _F,
|
||||||
*,
|
*,
|
||||||
|
can_constant_fold_through: bool = False,
|
||||||
skip_signature_check: bool = False,
|
skip_signature_check: bool = False,
|
||||||
) -> Callable[[_F], _F]:
|
) -> Callable[[_F], _F]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -142,6 +143,10 @@ def substitute_in_graph(
|
||||||
Args:
|
Args:
|
||||||
original_fn (callable): The original function, usually a C function, to register a polyfill
|
original_fn (callable): The original function, usually a C function, to register a polyfill
|
||||||
handler for.
|
handler for.
|
||||||
|
can_constant_fold_through (bool, optional): Whether the polyfill handler can be constant
|
||||||
|
folded through. That is, if the polyfill handler is a pure function and its arguments
|
||||||
|
are constant, the result of the polyfill handler can be constant folded during the
|
||||||
|
compilation. Defaults to ``False``.
|
||||||
skip_signature_check (bool, optional): Whether to skip the signature check between the
|
skip_signature_check (bool, optional): Whether to skip the signature check between the
|
||||||
original function and the polyfill handler. Defaults to ``False``.
|
original function and the polyfill handler. Defaults to ``False``.
|
||||||
|
|
||||||
|
|
@ -173,6 +178,7 @@ def substitute_in_graph(
|
||||||
|
|
||||||
return torch._dynamo.substitute_in_graph(
|
return torch._dynamo.substitute_in_graph(
|
||||||
original_fn,
|
original_fn,
|
||||||
|
can_constant_fold_through=can_constant_fold_through,
|
||||||
skip_signature_check=skip_signature_check,
|
skip_signature_check=skip_signature_check,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user