mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fixes #142306 This PR includes typing improvements and refactoring for the following files: - __init__.py - decorators.py - _ops.py Pull Request resolved: https://github.com/pytorch/pytorch/pull/144047 Approved by: https://github.com/XuehaiPan, https://github.com/Skylion007 Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com> Co-authored-by: Xuehai Pan <XuehaiPan@pku.edu.cn>
This commit is contained in:
parent
9225f149eb
commit
d4609af1ca
|
|
@ -5,6 +5,7 @@ import inspect
|
||||||
import sys
|
import sys
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Callable, Dict, Type, TYPE_CHECKING, TypeVar
|
from typing import Any, Callable, Dict, Type, TYPE_CHECKING, TypeVar
|
||||||
|
from typing_extensions import ParamSpec
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch._environment import is_fbcode
|
from torch._environment import is_fbcode
|
||||||
|
|
@ -45,7 +46,8 @@ else:
|
||||||
globals()[name] = getattr(torch._C._dynamo.eval_frame, name)
|
globals()[name] = getattr(torch._C._dynamo.eval_frame, name)
|
||||||
|
|
||||||
|
|
||||||
_F = TypeVar("_F", bound=Callable[..., Any])
|
_P = ParamSpec("_P")
|
||||||
|
_R = TypeVar("_R")
|
||||||
|
|
||||||
|
|
||||||
def run(fn=None):
|
def run(fn=None):
|
||||||
|
|
@ -219,13 +221,13 @@ def forbid_in_graph(fn):
|
||||||
|
|
||||||
|
|
||||||
def substitute_in_graph(
|
def substitute_in_graph(
|
||||||
original_fn: _F,
|
original_fn: Callable[_P, _R],
|
||||||
*,
|
*,
|
||||||
can_constant_fold_through: bool = False,
|
can_constant_fold_through: bool = False,
|
||||||
skip_signature_check: bool = False,
|
skip_signature_check: bool = False,
|
||||||
# type that is embedded in the Python interpreter
|
# type that is embedded in the Python interpreter
|
||||||
is_embedded_type: bool = False, # internal use only
|
is_embedded_type: bool = False, # internal use only
|
||||||
) -> Callable[[_F], _F]:
|
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
|
||||||
"""
|
"""
|
||||||
Register a polyfill handler for a function, usually a C function from the C extension, to be
|
Register a polyfill handler for a function, usually a C function from the C extension, to be
|
||||||
used in place of the original function when inlining the original function in the graph.
|
used in place of the original function when inlining the original function in the graph.
|
||||||
|
|
@ -291,7 +293,7 @@ def substitute_in_graph(
|
||||||
if id(original_fn) in ITERTOOLS_TYPE_IDS:
|
if id(original_fn) in ITERTOOLS_TYPE_IDS:
|
||||||
ITERTOOLS_POLYFILLED_TYPE_IDS.add(id(original_fn))
|
ITERTOOLS_POLYFILLED_TYPE_IDS.add(id(original_fn))
|
||||||
|
|
||||||
def wrapper(traceable_fn: _F) -> _F:
|
def wrapper(traceable_fn: Callable[_P, _R]) -> Callable[_P, _R]:
|
||||||
if not is_function(traceable_fn):
|
if not is_function(traceable_fn):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"@substitute_in_graph(...) expects a function but got {type(traceable_fn)!r}"
|
f"@substitute_in_graph(...) expects a function but got {type(traceable_fn)!r}"
|
||||||
|
|
@ -382,10 +384,10 @@ def substitute_in_graph(
|
||||||
# Need to wrap the function because we may cannot assign __torch_dynamo_polyfill__ to a
|
# Need to wrap the function because we may cannot assign __torch_dynamo_polyfill__ to a
|
||||||
# C++ function.
|
# C++ function.
|
||||||
@functools.wraps(traceable_fn)
|
@functools.wraps(traceable_fn)
|
||||||
def wrapped(*args, **kwargs):
|
def wrapped(*args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||||
return original_fn(*args, **kwargs)
|
return original_fn(*args, **kwargs)
|
||||||
|
|
||||||
def dispatch_fn(self, value: _F) -> PolyfilledFunctionVariable:
|
def dispatch_fn(self, value: Callable[_P, _R]) -> PolyfilledFunctionVariable:
|
||||||
return PolyfilledFunctionVariable(
|
return PolyfilledFunctionVariable(
|
||||||
value,
|
value,
|
||||||
source=self.source,
|
source=self.source,
|
||||||
|
|
|
||||||
|
|
@ -42,7 +42,7 @@ def chain_from_iterable(iterable: Iterable[Iterable[_T]], /) -> Iterator[_T]:
|
||||||
return itertools.chain(*iterable)
|
return itertools.chain(*iterable)
|
||||||
|
|
||||||
|
|
||||||
chain.from_iterable = chain_from_iterable # type: ignore[method-assign]
|
chain.from_iterable = chain_from_iterable # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
# Reference: https://docs.python.org/3/library/itertools.html#itertools.compress
|
# Reference: https://docs.python.org/3/library/itertools.html#itertools.compress
|
||||||
|
|
|
||||||
|
|
@ -6,24 +6,36 @@ import importlib
|
||||||
import inspect
|
import inspect
|
||||||
import sys
|
import sys
|
||||||
import types
|
import types
|
||||||
from typing import Any, Callable, Dict, List, Set, Type, TypeVar, Union
|
from typing import (
|
||||||
from typing_extensions import ParamSpec
|
Any,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Set,
|
||||||
|
Type,
|
||||||
|
TYPE_CHECKING,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
from typing_extensions import Concatenate, ParamSpec
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.utils._pytree as pytree
|
import torch.utils._pytree as pytree
|
||||||
from torch import _utils_internal
|
from torch import _utils_internal
|
||||||
from torch._C import _dispatch_is_included_in_alias as is_included_in_alias, DispatchKey
|
from torch._C import _dispatch_is_included_in_alias as is_included_in_alias, DispatchKey
|
||||||
from torch._functorch.pyfunctorch import dispatch_functorch
|
from torch._functorch.pyfunctorch import dispatch_functorch, TransformType
|
||||||
from torch.utils._python_dispatch import TorchDispatchMode
|
from torch.utils._python_dispatch import TorchDispatchMode
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from torch._subclasses.functional_tensor import BaseFunctionalizeAPI
|
||||||
|
|
||||||
|
|
||||||
_T = TypeVar("_T")
|
_T = TypeVar("_T")
|
||||||
_P = ParamSpec("_P")
|
_P = ParamSpec("_P")
|
||||||
|
|
||||||
|
|
||||||
_F = TypeVar("_F", bound=Callable[..., Any])
|
|
||||||
|
|
||||||
|
|
||||||
# Query `hasattr` only once.
|
# Query `hasattr` only once.
|
||||||
_SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags")
|
_SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags")
|
||||||
|
|
||||||
|
|
@ -112,11 +124,11 @@ class OperatorBase:
|
||||||
k: Union[
|
k: Union[
|
||||||
Type[TorchDispatchMode],
|
Type[TorchDispatchMode],
|
||||||
Type[torch.Tensor],
|
Type[torch.Tensor],
|
||||||
torch._C._functorch.TransformType,
|
TransformType,
|
||||||
torch._C.DispatchKey,
|
DispatchKey,
|
||||||
],
|
],
|
||||||
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
|
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
|
||||||
def inner(fn: _F) -> _F:
|
def inner(fn: Callable[_P, _T]) -> Callable[_P, _T]:
|
||||||
if inspect.isclass(k) and (
|
if inspect.isclass(k) and (
|
||||||
issubclass(k, TorchDispatchMode) or issubclass(k, torch.Tensor)
|
issubclass(k, TorchDispatchMode) or issubclass(k, torch.Tensor)
|
||||||
):
|
):
|
||||||
|
|
@ -126,7 +138,7 @@ class OperatorBase:
|
||||||
self._dispatch_cache.clear()
|
self._dispatch_cache.clear()
|
||||||
return fn
|
return fn
|
||||||
|
|
||||||
if isinstance(k, torch._C._functorch.TransformType):
|
if isinstance(k, TransformType):
|
||||||
assert k not in self.functorch_table
|
assert k not in self.functorch_table
|
||||||
self.functorch_table[k] = fn
|
self.functorch_table[k] = fn
|
||||||
return fn
|
return fn
|
||||||
|
|
@ -134,7 +146,7 @@ class OperatorBase:
|
||||||
assert isinstance(k, DispatchKey)
|
assert isinstance(k, DispatchKey)
|
||||||
assert (
|
assert (
|
||||||
k != DispatchKey.Python
|
k != DispatchKey.Python
|
||||||
), "Please register a mode for the torch._C.DispatchKey.Python key instead."
|
), "Please register a mode for the DispatchKey.Python key instead."
|
||||||
|
|
||||||
if k in self.py_kernels:
|
if k in self.py_kernels:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
|
|
@ -157,31 +169,34 @@ class OperatorBase:
|
||||||
# with ctx.redispatch_to_next():
|
# with ctx.redispatch_to_next():
|
||||||
# out = ctx.functionalize(inner_f)(*args_unwrapped)
|
# out = ctx.functionalize(inner_f)(*args_unwrapped)
|
||||||
# return ctx.wrap_tensors(out)
|
# return ctx.wrap_tensors(out)
|
||||||
def py_functionalize_impl(self, fn: _F) -> _F:
|
def py_functionalize_impl(
|
||||||
|
self, fn: Callable[Concatenate["BaseFunctionalizeAPI", _P], _T]
|
||||||
|
) -> Callable[Concatenate["BaseFunctionalizeAPI", _P], _T]:
|
||||||
from torch._subclasses.functional_tensor import (
|
from torch._subclasses.functional_tensor import (
|
||||||
CppFunctionalizeAPI as _CppFunctionalizeAPI,
|
CppFunctionalizeAPI,
|
||||||
FunctorchFunctionalizeAPI as _FunctorchFunctionalizeAPI,
|
FunctionalTensorMode,
|
||||||
PythonFunctionalizeAPI as _PythonFunctionalizeAPI,
|
FunctorchFunctionalizeAPI,
|
||||||
|
PythonFunctionalizeAPI,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Construct our three flavors of functionalization,
|
# Construct our three flavors of functionalization,
|
||||||
# each of which have slightly different wrap/unwrap/redispatch policies
|
# each of which have slightly different wrap/unwrap/redispatch policies
|
||||||
def functionalize_dk_fn(*args, **kwargs):
|
def functionalize_dk_fn(*args: _P.args, **kwargs: _P.kwargs) -> _T:
|
||||||
return fn(_CppFunctionalizeAPI(), *args, **kwargs)
|
return fn(CppFunctionalizeAPI(), *args, **kwargs)
|
||||||
|
|
||||||
def functionalize_dispatch_mode_fn(mode, *args, **kwargs):
|
def functionalize_dispatch_mode_fn(
|
||||||
return fn(_PythonFunctionalizeAPI(mode), *args, **kwargs)
|
mode: Optional[FunctionalTensorMode], *args: _P.args, **kwargs: _P.kwargs
|
||||||
|
) -> _T:
|
||||||
|
return fn(PythonFunctionalizeAPI(mode), *args, **kwargs)
|
||||||
|
|
||||||
def functionalize_functorch_fn(interpreter, *args, **kwargs):
|
def functionalize_functorch_fn(
|
||||||
return fn(_FunctorchFunctionalizeAPI(interpreter), *args, **kwargs)
|
interpreter, *args: _P.args, **kwargs: _P.kwargs
|
||||||
|
) -> _T:
|
||||||
|
return fn(FunctorchFunctionalizeAPI(interpreter), *args, **kwargs)
|
||||||
|
|
||||||
self.py_impl(DispatchKey.Functionalize)(functionalize_dk_fn)
|
self.py_impl(DispatchKey.Functionalize)(functionalize_dk_fn)
|
||||||
self.py_impl(torch._subclasses.functional_tensor.FunctionalTensorMode)(
|
self.py_impl(FunctionalTensorMode)(functionalize_dispatch_mode_fn)
|
||||||
functionalize_dispatch_mode_fn
|
self.py_impl(TransformType.Functionalize)(functionalize_functorch_fn)
|
||||||
)
|
|
||||||
self.py_impl(torch._C._functorch.TransformType.Functionalize)(
|
|
||||||
functionalize_functorch_fn
|
|
||||||
)
|
|
||||||
|
|
||||||
return fn
|
return fn
|
||||||
|
|
||||||
|
|
@ -294,7 +309,7 @@ class HigherOrderOperator(OperatorBase, abc.ABC):
|
||||||
k: Union[
|
k: Union[
|
||||||
Type[TorchDispatchMode],
|
Type[TorchDispatchMode],
|
||||||
Type[torch.Tensor],
|
Type[torch.Tensor],
|
||||||
torch._C._functorch.TransformType,
|
TransformType,
|
||||||
DispatchKey,
|
DispatchKey,
|
||||||
],
|
],
|
||||||
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
|
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
|
||||||
|
|
@ -408,7 +423,7 @@ class HigherOrderOperator(OperatorBase, abc.ABC):
|
||||||
curr_mode = _get_current_dispatch_mode_pre_dispatch()
|
curr_mode = _get_current_dispatch_mode_pre_dispatch()
|
||||||
assert (
|
assert (
|
||||||
curr_mode is not None
|
curr_mode is not None
|
||||||
), "Illegal invocation of dispatch on torch._C.DispatchKey.PreDispatch without a mode."
|
), "Illegal invocation of dispatch on DispatchKey.PreDispatch without a mode."
|
||||||
assert (
|
assert (
|
||||||
type(curr_mode) in self.python_key_table
|
type(curr_mode) in self.python_key_table
|
||||||
), f"Current active mode {curr_mode} not registered"
|
), f"Current active mode {curr_mode} not registered"
|
||||||
|
|
@ -817,7 +832,7 @@ class OpOverload(OperatorBase):
|
||||||
curr_mode = type(_get_current_dispatch_mode())
|
curr_mode = type(_get_current_dispatch_mode())
|
||||||
assert (
|
assert (
|
||||||
curr_mode is not None
|
curr_mode is not None
|
||||||
), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode."
|
), "Illegal invocation of dispatch on DispatchKey.Python without a mode."
|
||||||
|
|
||||||
if curr_mode not in self.python_key_table:
|
if curr_mode not in self.python_key_table:
|
||||||
if isinstance(self, TorchBindOpOverload):
|
if isinstance(self, TorchBindOpOverload):
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
# mypy: allow-untyped-defs
|
# mypy: allow-untyped-defs
|
||||||
from typing import Any, Callable, List, TypeVar
|
from typing import Any, Callable, List, TypeVar
|
||||||
|
from typing_extensions import ParamSpec
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
@ -21,7 +22,8 @@ __all__ = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
_F = TypeVar("_F", bound=Callable[..., Any])
|
_P = ParamSpec("_P")
|
||||||
|
_R = TypeVar("_R")
|
||||||
|
|
||||||
|
|
||||||
def compile(*args, **kwargs):
|
def compile(*args, **kwargs):
|
||||||
|
|
@ -124,11 +126,11 @@ def allow_in_graph(fn):
|
||||||
|
|
||||||
|
|
||||||
def substitute_in_graph(
|
def substitute_in_graph(
|
||||||
original_fn: _F,
|
original_fn: Callable[_P, _R],
|
||||||
*,
|
*,
|
||||||
can_constant_fold_through: bool = False,
|
can_constant_fold_through: bool = False,
|
||||||
skip_signature_check: bool = False,
|
skip_signature_check: bool = False,
|
||||||
) -> Callable[[_F], _F]:
|
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
|
||||||
"""
|
"""
|
||||||
Register a polyfill handler for a function, usually a C function from the C extension, to be
|
Register a polyfill handler for a function, usually a C function from the C extension, to be
|
||||||
used in place of the original function when inlining the original function in the graph.
|
used in place of the original function when inlining the original function in the graph.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user