diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index cd9a7209bf1..41cb2117efa 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -5,6 +5,7 @@ import inspect import sys from dataclasses import dataclass from typing import Any, Callable, Dict, Type, TYPE_CHECKING, TypeVar +from typing_extensions import ParamSpec import torch from torch._environment import is_fbcode @@ -45,7 +46,8 @@ else: globals()[name] = getattr(torch._C._dynamo.eval_frame, name) -_F = TypeVar("_F", bound=Callable[..., Any]) +_P = ParamSpec("_P") +_R = TypeVar("_R") def run(fn=None): @@ -219,13 +221,13 @@ def forbid_in_graph(fn): def substitute_in_graph( - original_fn: _F, + original_fn: Callable[_P, _R], *, can_constant_fold_through: bool = False, skip_signature_check: bool = False, # type that is embedded in the Python interpreter is_embedded_type: bool = False, # internal use only -) -> Callable[[_F], _F]: +) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]: """ Register a polyfill handler for a function, usually a C function from the C extension, to be used in place of the original function when inlining the original function in the graph. @@ -291,7 +293,7 @@ def substitute_in_graph( if id(original_fn) in ITERTOOLS_TYPE_IDS: ITERTOOLS_POLYFILLED_TYPE_IDS.add(id(original_fn)) - def wrapper(traceable_fn: _F) -> _F: + def wrapper(traceable_fn: Callable[_P, _R]) -> Callable[_P, _R]: if not is_function(traceable_fn): raise TypeError( f"@substitute_in_graph(...) expects a function but got {type(traceable_fn)!r}" @@ -382,10 +384,10 @@ def substitute_in_graph( # Need to wrap the function because we may cannot assign __torch_dynamo_polyfill__ to a # C++ function. @functools.wraps(traceable_fn) - def wrapped(*args, **kwargs): + def wrapped(*args: _P.args, **kwargs: _P.kwargs) -> _R: return original_fn(*args, **kwargs) - def dispatch_fn(self, value: _F) -> PolyfilledFunctionVariable: + def dispatch_fn(self, value: Callable[_P, _R]) -> PolyfilledFunctionVariable: return PolyfilledFunctionVariable( value, source=self.source, diff --git a/torch/_dynamo/polyfills/itertools.py b/torch/_dynamo/polyfills/itertools.py index 368f73ae90e..8a43b2af043 100644 --- a/torch/_dynamo/polyfills/itertools.py +++ b/torch/_dynamo/polyfills/itertools.py @@ -42,7 +42,7 @@ def chain_from_iterable(iterable: Iterable[Iterable[_T]], /) -> Iterator[_T]: return itertools.chain(*iterable) -chain.from_iterable = chain_from_iterable # type: ignore[method-assign] +chain.from_iterable = chain_from_iterable # type: ignore[attr-defined] # Reference: https://docs.python.org/3/library/itertools.html#itertools.compress diff --git a/torch/_ops.py b/torch/_ops.py index 4f39b789e47..9058be5400d 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -6,24 +6,36 @@ import importlib import inspect import sys import types -from typing import Any, Callable, Dict, List, Set, Type, TypeVar, Union -from typing_extensions import ParamSpec +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Set, + Type, + TYPE_CHECKING, + TypeVar, + Union, +) +from typing_extensions import Concatenate, ParamSpec import torch import torch.utils._pytree as pytree from torch import _utils_internal from torch._C import _dispatch_is_included_in_alias as is_included_in_alias, DispatchKey -from torch._functorch.pyfunctorch import dispatch_functorch +from torch._functorch.pyfunctorch import dispatch_functorch, TransformType from torch.utils._python_dispatch import TorchDispatchMode +if TYPE_CHECKING: + from torch._subclasses.functional_tensor import BaseFunctionalizeAPI + + _T = TypeVar("_T") _P = ParamSpec("_P") -_F = TypeVar("_F", bound=Callable[..., Any]) - - # Query `hasattr` only once. _SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags") @@ -112,11 +124,11 @@ class OperatorBase: k: Union[ Type[TorchDispatchMode], Type[torch.Tensor], - torch._C._functorch.TransformType, - torch._C.DispatchKey, + TransformType, + DispatchKey, ], ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: - def inner(fn: _F) -> _F: + def inner(fn: Callable[_P, _T]) -> Callable[_P, _T]: if inspect.isclass(k) and ( issubclass(k, TorchDispatchMode) or issubclass(k, torch.Tensor) ): @@ -126,7 +138,7 @@ class OperatorBase: self._dispatch_cache.clear() return fn - if isinstance(k, torch._C._functorch.TransformType): + if isinstance(k, TransformType): assert k not in self.functorch_table self.functorch_table[k] = fn return fn @@ -134,7 +146,7 @@ class OperatorBase: assert isinstance(k, DispatchKey) assert ( k != DispatchKey.Python - ), "Please register a mode for the torch._C.DispatchKey.Python key instead." + ), "Please register a mode for the DispatchKey.Python key instead." if k in self.py_kernels: raise RuntimeError( @@ -157,31 +169,34 @@ class OperatorBase: # with ctx.redispatch_to_next(): # out = ctx.functionalize(inner_f)(*args_unwrapped) # return ctx.wrap_tensors(out) - def py_functionalize_impl(self, fn: _F) -> _F: + def py_functionalize_impl( + self, fn: Callable[Concatenate["BaseFunctionalizeAPI", _P], _T] + ) -> Callable[Concatenate["BaseFunctionalizeAPI", _P], _T]: from torch._subclasses.functional_tensor import ( - CppFunctionalizeAPI as _CppFunctionalizeAPI, - FunctorchFunctionalizeAPI as _FunctorchFunctionalizeAPI, - PythonFunctionalizeAPI as _PythonFunctionalizeAPI, + CppFunctionalizeAPI, + FunctionalTensorMode, + FunctorchFunctionalizeAPI, + PythonFunctionalizeAPI, ) # Construct our three flavors of functionalization, # each of which have slightly different wrap/unwrap/redispatch policies - def functionalize_dk_fn(*args, **kwargs): - return fn(_CppFunctionalizeAPI(), *args, **kwargs) + def functionalize_dk_fn(*args: _P.args, **kwargs: _P.kwargs) -> _T: + return fn(CppFunctionalizeAPI(), *args, **kwargs) - def functionalize_dispatch_mode_fn(mode, *args, **kwargs): - return fn(_PythonFunctionalizeAPI(mode), *args, **kwargs) + def functionalize_dispatch_mode_fn( + mode: Optional[FunctionalTensorMode], *args: _P.args, **kwargs: _P.kwargs + ) -> _T: + return fn(PythonFunctionalizeAPI(mode), *args, **kwargs) - def functionalize_functorch_fn(interpreter, *args, **kwargs): - return fn(_FunctorchFunctionalizeAPI(interpreter), *args, **kwargs) + def functionalize_functorch_fn( + interpreter, *args: _P.args, **kwargs: _P.kwargs + ) -> _T: + return fn(FunctorchFunctionalizeAPI(interpreter), *args, **kwargs) self.py_impl(DispatchKey.Functionalize)(functionalize_dk_fn) - self.py_impl(torch._subclasses.functional_tensor.FunctionalTensorMode)( - functionalize_dispatch_mode_fn - ) - self.py_impl(torch._C._functorch.TransformType.Functionalize)( - functionalize_functorch_fn - ) + self.py_impl(FunctionalTensorMode)(functionalize_dispatch_mode_fn) + self.py_impl(TransformType.Functionalize)(functionalize_functorch_fn) return fn @@ -294,7 +309,7 @@ class HigherOrderOperator(OperatorBase, abc.ABC): k: Union[ Type[TorchDispatchMode], Type[torch.Tensor], - torch._C._functorch.TransformType, + TransformType, DispatchKey, ], ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: @@ -408,7 +423,7 @@ class HigherOrderOperator(OperatorBase, abc.ABC): curr_mode = _get_current_dispatch_mode_pre_dispatch() assert ( curr_mode is not None - ), "Illegal invocation of dispatch on torch._C.DispatchKey.PreDispatch without a mode." + ), "Illegal invocation of dispatch on DispatchKey.PreDispatch without a mode." assert ( type(curr_mode) in self.python_key_table ), f"Current active mode {curr_mode} not registered" @@ -817,7 +832,7 @@ class OpOverload(OperatorBase): curr_mode = type(_get_current_dispatch_mode()) assert ( curr_mode is not None - ), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode." + ), "Illegal invocation of dispatch on DispatchKey.Python without a mode." if curr_mode not in self.python_key_table: if isinstance(self, TorchBindOpOverload): diff --git a/torch/compiler/__init__.py b/torch/compiler/__init__.py index 038f36d20e2..cb65bd34b87 100644 --- a/torch/compiler/__init__.py +++ b/torch/compiler/__init__.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs from typing import Any, Callable, List, TypeVar +from typing_extensions import ParamSpec import torch @@ -21,7 +22,8 @@ __all__ = [ ] -_F = TypeVar("_F", bound=Callable[..., Any]) +_P = ParamSpec("_P") +_R = TypeVar("_R") def compile(*args, **kwargs): @@ -124,11 +126,11 @@ def allow_in_graph(fn): def substitute_in_graph( - original_fn: _F, + original_fn: Callable[_P, _R], *, can_constant_fold_through: bool = False, skip_signature_check: bool = False, -) -> Callable[[_F], _F]: +) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]: """ Register a polyfill handler for a function, usually a C function from the C extension, to be used in place of the original function when inlining the original function in the graph.