Fixes #142306

This PR includes typing improvements and refactoring for the following files:
- __init__.py
- decorators.py
- _ops.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144047
Approved by: https://github.com/XuehaiPan, https://github.com/Skylion007

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
Co-authored-by: Xuehai Pan <XuehaiPan@pku.edu.cn>
parent 9225f149eb
commit d4609af1ca
@@ -5,6 +5,7 @@ import inspect
 import sys
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, Type, TYPE_CHECKING, TypeVar
+from typing_extensions import ParamSpec
 
 import torch
 from torch._environment import is_fbcode
@@ -45,7 +46,8 @@ else:
         globals()[name] = getattr(torch._C._dynamo.eval_frame, name)
 
 
-_F = TypeVar("_F", bound=Callable[..., Any])
+_P = ParamSpec("_P")
+_R = TypeVar("_R")
 
 
 def run(fn=None):
@@ -219,13 +221,13 @@ def forbid_in_graph(fn):
 
 
 def substitute_in_graph(
-    original_fn: _F,
+    original_fn: Callable[_P, _R],
     *,
     can_constant_fold_through: bool = False,
     skip_signature_check: bool = False,
     # type that is embedded in the Python interpreter
     is_embedded_type: bool = False,  # internal use only
-) -> Callable[[_F], _F]:
+) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
     """
     Register a polyfill handler for a function, usually a C function from the C extension, to be
     used in place of the original function when inlining the original function in the graph.
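The hunk above retypes substitute_in_graph so the original function's parameter list (_P) and return type (_R) flow through to the registered polyfill, instead of collapsing both into a single _F type variable. A minimal, self-contained sketch of this ParamSpec-based decorator-factory pattern, using illustrative names rather than the PR's real helpers:

    from typing import Callable, TypeVar
    from typing_extensions import ParamSpec

    _P = ParamSpec("_P")
    _R = TypeVar("_R")

    def register_substitute(
        original_fn: Callable[_P, _R],
    ) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
        # The returned decorator only accepts replacements whose parameters
        # and return type match original_fn, and hands them back unchanged.
        def decorator(replacement: Callable[_P, _R]) -> Callable[_P, _R]:
            return replacement

        return decorator

    def clamp(x: float, lo: float, hi: float) -> float:
        return max(lo, min(x, hi))

    @register_substitute(clamp)
    def clamp_polyfill(x: float, lo: float, hi: float) -> float:
        return lo if x < lo else (hi if x > hi else x)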
@@ -291,7 +293,7 @@ def substitute_in_graph(
     if id(original_fn) in ITERTOOLS_TYPE_IDS:
         ITERTOOLS_POLYFILLED_TYPE_IDS.add(id(original_fn))
 
-    def wrapper(traceable_fn: _F) -> _F:
+    def wrapper(traceable_fn: Callable[_P, _R]) -> Callable[_P, _R]:
         if not is_function(traceable_fn):
             raise TypeError(
                 f"@substitute_in_graph(...) expects a function but got {type(traceable_fn)!r}"
@@ -382,10 +384,10 @@ def substitute_in_graph(
         # Need to wrap the function because we may cannot assign __torch_dynamo_polyfill__ to a
         # C++ function.
         @functools.wraps(traceable_fn)
-        def wrapped(*args, **kwargs):
+        def wrapped(*args: _P.args, **kwargs: _P.kwargs) -> _R:
             return original_fn(*args, **kwargs)
 
-        def dispatch_fn(self, value: _F) -> PolyfilledFunctionVariable:
+        def dispatch_fn(self, value: Callable[_P, _R]) -> PolyfilledFunctionVariable:
             return PolyfilledFunctionVariable(
                 value,
                 source=self.source,
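The wrapped helper above uses the standard way to type a pass-through wrapper with ParamSpec: annotating *args and **kwargs as _P.args and _P.kwargs lets type checkers treat the wrapper's call signature as identical to the wrapped callable's. A self-contained sketch of that pattern (the log_calls decorator is a made-up example, not part of this diff):

    import functools
    from typing import Callable, TypeVar
    from typing_extensions import ParamSpec

    _P = ParamSpec("_P")
    _R = TypeVar("_R")

    def log_calls(fn: Callable[_P, _R]) -> Callable[_P, _R]:
        @functools.wraps(fn)
        def wrapped(*args: _P.args, **kwargs: _P.kwargs) -> _R:
            # Forwarding the ParamSpec-typed *args/**kwargs keeps fn's exact
            # parameter list visible to callers of the wrapper.
            print(f"calling {fn.__name__}")
            return fn(*args, **kwargs)

        return wrapped

    @log_calls
    def add(x: int, y: int) -> int:
        return x + y

    add(1, 2)      # OK
    # add(1, "2")  # rejected by a type checker: y must be int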
@@ -42,7 +42,7 @@ def chain_from_iterable(iterable: Iterable[Iterable[_T]], /) -> Iterator[_T]:
     return itertools.chain(*iterable)
 
 
-chain.from_iterable = chain_from_iterable  # type: ignore[method-assign]
+chain.from_iterable = chain_from_iterable  # type: ignore[attr-defined]
 
 
 # Reference: https://docs.python.org/3/library/itertools.html#itertools.compress
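The one-line change above only swaps which mypy error code is suppressed. Assigning an attribute that the target's type does not declare is reported as [attr-defined] ("has no attribute ..."), while [method-assign] is the code for overwriting a declared method. A small illustration of the distinction, with illustrative stand-in functions rather than the polyfill module's real layout:

    def chain(*iterables):
        ...

    def chain_from_iterable(iterable):
        ...

    # Assigning a new attribute onto a plain function: mypy flags this as
    # [attr-defined], not [method-assign].
    chain.from_iterable = chain_from_iterable  # type: ignore[attr-defined]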
@@ -6,24 +6,36 @@ import importlib
 import inspect
 import sys
 import types
-from typing import Any, Callable, Dict, List, Set, Type, TypeVar, Union
-from typing_extensions import ParamSpec
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Set,
+    Type,
+    TYPE_CHECKING,
+    TypeVar,
+    Union,
+)
+from typing_extensions import Concatenate, ParamSpec
 
 import torch
 import torch.utils._pytree as pytree
 from torch import _utils_internal
 from torch._C import _dispatch_is_included_in_alias as is_included_in_alias, DispatchKey
-from torch._functorch.pyfunctorch import dispatch_functorch
+from torch._functorch.pyfunctorch import dispatch_functorch, TransformType
 from torch.utils._python_dispatch import TorchDispatchMode
 
 
+if TYPE_CHECKING:
+    from torch._subclasses.functional_tensor import BaseFunctionalizeAPI
+
+
 _T = TypeVar("_T")
 _P = ParamSpec("_P")
-
-
-_F = TypeVar("_F", bound=Callable[..., Any])
 
 
 # Query `hasattr` only once.
 _SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags")
 
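The new if TYPE_CHECKING: block above makes BaseFunctionalizeAPI available for annotations without importing torch._subclasses.functional_tensor at module-import time; the annotations that use it are written as strings so they are never evaluated at runtime. A small self-contained sketch of the general pattern (the functionalize function is illustrative only):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Imported only while type checking; no runtime import cost or cycle.
        from torch._subclasses.functional_tensor import BaseFunctionalizeAPI

    def functionalize(api: "BaseFunctionalizeAPI") -> None:
        # The quoted annotation keeps the name unresolved at runtime.
        ...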
@@ -112,11 +124,11 @@ class OperatorBase:
         k: Union[
             Type[TorchDispatchMode],
             Type[torch.Tensor],
-            torch._C._functorch.TransformType,
-            torch._C.DispatchKey,
+            TransformType,
+            DispatchKey,
         ],
     ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
-        def inner(fn: _F) -> _F:
+        def inner(fn: Callable[_P, _T]) -> Callable[_P, _T]:
             if inspect.isclass(k) and (
                 issubclass(k, TorchDispatchMode) or issubclass(k, torch.Tensor)
             ):
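py_impl above is typed to return Callable[[Callable[_P, _T]], Callable[_P, _T]]: the registration decorator accepts a kernel, records it, and returns the very same callable, so the kernel keeps its own signature after decoration. A standalone sketch of that shape with purely illustrative names (not the real torch registry):

    from typing import Callable, Dict, TypeVar
    from typing_extensions import ParamSpec

    _P = ParamSpec("_P")
    _T = TypeVar("_T")

    class KernelRegistry:
        def __init__(self) -> None:
            self.kernels: Dict[str, Callable[..., object]] = {}

        def impl(self, key: str) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
            def inner(fn: Callable[_P, _T]) -> Callable[_P, _T]:
                self.kernels[key] = fn
                return fn  # returned unchanged, so the decorated name keeps its type

            return inner

    registry = KernelRegistry()

    @registry.impl("add")
    def add_kernel(x: int, y: int) -> int:
        return x + y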
@@ -126,7 +138,7 @@ class OperatorBase:
                 self._dispatch_cache.clear()
                 return fn
 
-        if isinstance(k, torch._C._functorch.TransformType):
+        if isinstance(k, TransformType):
             assert k not in self.functorch_table
             self.functorch_table[k] = fn
             return fn
@@ -134,7 +146,7 @@ class OperatorBase:
             assert isinstance(k, DispatchKey)
             assert (
                 k != DispatchKey.Python
-            ), "Please register a mode for the torch._C.DispatchKey.Python key instead."
+            ), "Please register a mode for the DispatchKey.Python key instead."
 
             if k in self.py_kernels:
                 raise RuntimeError(
@@ -157,31 +169,34 @@ class OperatorBase:
     # with ctx.redispatch_to_next():
     #     out = ctx.functionalize(inner_f)(*args_unwrapped)
     #     return ctx.wrap_tensors(out)
-    def py_functionalize_impl(self, fn: _F) -> _F:
+    def py_functionalize_impl(
+        self, fn: Callable[Concatenate["BaseFunctionalizeAPI", _P], _T]
+    ) -> Callable[Concatenate["BaseFunctionalizeAPI", _P], _T]:
         from torch._subclasses.functional_tensor import (
-            CppFunctionalizeAPI as _CppFunctionalizeAPI,
-            FunctorchFunctionalizeAPI as _FunctorchFunctionalizeAPI,
-            PythonFunctionalizeAPI as _PythonFunctionalizeAPI,
+            CppFunctionalizeAPI,
+            FunctionalTensorMode,
+            FunctorchFunctionalizeAPI,
+            PythonFunctionalizeAPI,
         )
 
         # Construct our three flavors of functionalization,
         # each of which have slightly different wrap/unwrap/redispatch policies
-        def functionalize_dk_fn(*args, **kwargs):
-            return fn(_CppFunctionalizeAPI(), *args, **kwargs)
+        def functionalize_dk_fn(*args: _P.args, **kwargs: _P.kwargs) -> _T:
+            return fn(CppFunctionalizeAPI(), *args, **kwargs)
 
-        def functionalize_dispatch_mode_fn(mode, *args, **kwargs):
-            return fn(_PythonFunctionalizeAPI(mode), *args, **kwargs)
+        def functionalize_dispatch_mode_fn(
+            mode: Optional[FunctionalTensorMode], *args: _P.args, **kwargs: _P.kwargs
+        ) -> _T:
+            return fn(PythonFunctionalizeAPI(mode), *args, **kwargs)
 
-        def functionalize_functorch_fn(interpreter, *args, **kwargs):
-            return fn(_FunctorchFunctionalizeAPI(interpreter), *args, **kwargs)
+        def functionalize_functorch_fn(
+            interpreter, *args: _P.args, **kwargs: _P.kwargs
+        ) -> _T:
+            return fn(FunctorchFunctionalizeAPI(interpreter), *args, **kwargs)
 
         self.py_impl(DispatchKey.Functionalize)(functionalize_dk_fn)
-        self.py_impl(torch._subclasses.functional_tensor.FunctionalTensorMode)(
-            functionalize_dispatch_mode_fn
-        )
-        self.py_impl(torch._C._functorch.TransformType.Functionalize)(
-            functionalize_functorch_fn
-        )
+        self.py_impl(FunctionalTensorMode)(functionalize_dispatch_mode_fn)
+        self.py_impl(TransformType.Functionalize)(functionalize_functorch_fn)
 
         return fn
 
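py_functionalize_impl now uses Concatenate to say that fn expects a BaseFunctionalizeAPI as its first argument, followed by whatever parameters _P captures; the three generated wrappers then supply that leading argument themselves. A minimal, self-contained sketch of the Concatenate pattern, where Ctx and with_ctx are illustrative stand-ins rather than the real API:

    from typing import Callable, TypeVar
    from typing_extensions import Concatenate, ParamSpec

    _P = ParamSpec("_P")
    _T = TypeVar("_T")

    class Ctx:
        ...

    def with_ctx(fn: Callable[Concatenate[Ctx, _P], _T]) -> Callable[_P, _T]:
        def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
            # The wrapper injects the leading Ctx argument; callers only
            # provide the remaining _P parameters.
            return fn(Ctx(), *args, **kwargs)

        return wrapper

    @with_ctx
    def scale(ctx: Ctx, x: float, factor: float) -> float:
        return x * factor

    scale(2.0, factor=3.0)  # ctx is supplied by the wrapper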
@@ -294,7 +309,7 @@ class HigherOrderOperator(OperatorBase, abc.ABC):
         k: Union[
             Type[TorchDispatchMode],
             Type[torch.Tensor],
-            torch._C._functorch.TransformType,
+            TransformType,
             DispatchKey,
         ],
     ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
@@ -408,7 +423,7 @@ class HigherOrderOperator(OperatorBase, abc.ABC):
             curr_mode = _get_current_dispatch_mode_pre_dispatch()
             assert (
                 curr_mode is not None
-            ), "Illegal invocation of dispatch on torch._C.DispatchKey.PreDispatch without a mode."
+            ), "Illegal invocation of dispatch on DispatchKey.PreDispatch without a mode."
             assert (
                 type(curr_mode) in self.python_key_table
             ), f"Current active mode {curr_mode} not registered"
@@ -817,7 +832,7 @@ class OpOverload(OperatorBase):
                 curr_mode = type(_get_current_dispatch_mode())
                 assert (
                     curr_mode is not None
-                ), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode."
+                ), "Illegal invocation of dispatch on DispatchKey.Python without a mode."
 
                 if curr_mode not in self.python_key_table:
                     if isinstance(self, TorchBindOpOverload):
@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
 from typing import Any, Callable, List, TypeVar
+from typing_extensions import ParamSpec
 
 import torch
 
@@ -21,7 +22,8 @@ __all__ = [
 ]
 
 
-_F = TypeVar("_F", bound=Callable[..., Any])
+_P = ParamSpec("_P")
+_R = TypeVar("_R")
 
 
 def compile(*args, **kwargs):
@@ -124,11 +126,11 @@ def allow_in_graph(fn):
 
 
 def substitute_in_graph(
-    original_fn: _F,
+    original_fn: Callable[_P, _R],
     *,
     can_constant_fold_through: bool = False,
     skip_signature_check: bool = False,
-) -> Callable[[_F], _F]:
+) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
     """
     Register a polyfill handler for a function, usually a C function from the C extension, to be
     used in place of the original function when inlining the original function in the graph.