Propagate callable parameter types using ParamSpec (#142306) (#144047)

Fixes #142306

This PR includes typing improvements and refactoring for the following files:
- __init__.py
- decorators.py
- _ops.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144047
Approved by: https://github.com/XuehaiPan, https://github.com/Skylion007

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
Co-authored-by: Xuehai Pan <XuehaiPan@pku.edu.cn>
This commit is contained in:
yijun-lee 2025-01-06 16:16:16 +00:00 committed by PyTorch MergeBot
parent 9225f149eb
commit d4609af1ca
4 changed files with 59 additions and 40 deletions

View File

@ -5,6 +5,7 @@ import inspect
import sys import sys
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Callable, Dict, Type, TYPE_CHECKING, TypeVar from typing import Any, Callable, Dict, Type, TYPE_CHECKING, TypeVar
from typing_extensions import ParamSpec
import torch import torch
from torch._environment import is_fbcode from torch._environment import is_fbcode
@ -45,7 +46,8 @@ else:
globals()[name] = getattr(torch._C._dynamo.eval_frame, name) globals()[name] = getattr(torch._C._dynamo.eval_frame, name)
_F = TypeVar("_F", bound=Callable[..., Any]) _P = ParamSpec("_P")
_R = TypeVar("_R")
def run(fn=None): def run(fn=None):
@ -219,13 +221,13 @@ def forbid_in_graph(fn):
def substitute_in_graph( def substitute_in_graph(
original_fn: _F, original_fn: Callable[_P, _R],
*, *,
can_constant_fold_through: bool = False, can_constant_fold_through: bool = False,
skip_signature_check: bool = False, skip_signature_check: bool = False,
# type that is embedded in the Python interpreter # type that is embedded in the Python interpreter
is_embedded_type: bool = False, # internal use only is_embedded_type: bool = False, # internal use only
) -> Callable[[_F], _F]: ) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
""" """
Register a polyfill handler for a function, usually a C function from the C extension, to be Register a polyfill handler for a function, usually a C function from the C extension, to be
used in place of the original function when inlining the original function in the graph. used in place of the original function when inlining the original function in the graph.
@ -291,7 +293,7 @@ def substitute_in_graph(
if id(original_fn) in ITERTOOLS_TYPE_IDS: if id(original_fn) in ITERTOOLS_TYPE_IDS:
ITERTOOLS_POLYFILLED_TYPE_IDS.add(id(original_fn)) ITERTOOLS_POLYFILLED_TYPE_IDS.add(id(original_fn))
def wrapper(traceable_fn: _F) -> _F: def wrapper(traceable_fn: Callable[_P, _R]) -> Callable[_P, _R]:
if not is_function(traceable_fn): if not is_function(traceable_fn):
raise TypeError( raise TypeError(
f"@substitute_in_graph(...) expects a function but got {type(traceable_fn)!r}" f"@substitute_in_graph(...) expects a function but got {type(traceable_fn)!r}"
@ -382,10 +384,10 @@ def substitute_in_graph(
# Need to wrap the function because we may cannot assign __torch_dynamo_polyfill__ to a # Need to wrap the function because we may cannot assign __torch_dynamo_polyfill__ to a
# C++ function. # C++ function.
@functools.wraps(traceable_fn) @functools.wraps(traceable_fn)
def wrapped(*args, **kwargs): def wrapped(*args: _P.args, **kwargs: _P.kwargs) -> _R:
return original_fn(*args, **kwargs) return original_fn(*args, **kwargs)
def dispatch_fn(self, value: _F) -> PolyfilledFunctionVariable: def dispatch_fn(self, value: Callable[_P, _R]) -> PolyfilledFunctionVariable:
return PolyfilledFunctionVariable( return PolyfilledFunctionVariable(
value, value,
source=self.source, source=self.source,

View File

@ -42,7 +42,7 @@ def chain_from_iterable(iterable: Iterable[Iterable[_T]], /) -> Iterator[_T]:
return itertools.chain(*iterable) return itertools.chain(*iterable)
chain.from_iterable = chain_from_iterable # type: ignore[method-assign] chain.from_iterable = chain_from_iterable # type: ignore[attr-defined]
# Reference: https://docs.python.org/3/library/itertools.html#itertools.compress # Reference: https://docs.python.org/3/library/itertools.html#itertools.compress

View File

@ -6,24 +6,36 @@ import importlib
import inspect import inspect
import sys import sys
import types import types
from typing import Any, Callable, Dict, List, Set, Type, TypeVar, Union from typing import (
from typing_extensions import ParamSpec Any,
Callable,
Dict,
List,
Optional,
Set,
Type,
TYPE_CHECKING,
TypeVar,
Union,
)
from typing_extensions import Concatenate, ParamSpec
import torch import torch
import torch.utils._pytree as pytree import torch.utils._pytree as pytree
from torch import _utils_internal from torch import _utils_internal
from torch._C import _dispatch_is_included_in_alias as is_included_in_alias, DispatchKey from torch._C import _dispatch_is_included_in_alias as is_included_in_alias, DispatchKey
from torch._functorch.pyfunctorch import dispatch_functorch from torch._functorch.pyfunctorch import dispatch_functorch, TransformType
from torch.utils._python_dispatch import TorchDispatchMode from torch.utils._python_dispatch import TorchDispatchMode
if TYPE_CHECKING:
from torch._subclasses.functional_tensor import BaseFunctionalizeAPI
_T = TypeVar("_T") _T = TypeVar("_T")
_P = ParamSpec("_P") _P = ParamSpec("_P")
_F = TypeVar("_F", bound=Callable[..., Any])
# Query `hasattr` only once. # Query `hasattr` only once.
_SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags") _SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags")
@ -112,11 +124,11 @@ class OperatorBase:
k: Union[ k: Union[
Type[TorchDispatchMode], Type[TorchDispatchMode],
Type[torch.Tensor], Type[torch.Tensor],
torch._C._functorch.TransformType, TransformType,
torch._C.DispatchKey, DispatchKey,
], ],
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
def inner(fn: _F) -> _F: def inner(fn: Callable[_P, _T]) -> Callable[_P, _T]:
if inspect.isclass(k) and ( if inspect.isclass(k) and (
issubclass(k, TorchDispatchMode) or issubclass(k, torch.Tensor) issubclass(k, TorchDispatchMode) or issubclass(k, torch.Tensor)
): ):
@ -126,7 +138,7 @@ class OperatorBase:
self._dispatch_cache.clear() self._dispatch_cache.clear()
return fn return fn
if isinstance(k, torch._C._functorch.TransformType): if isinstance(k, TransformType):
assert k not in self.functorch_table assert k not in self.functorch_table
self.functorch_table[k] = fn self.functorch_table[k] = fn
return fn return fn
@ -134,7 +146,7 @@ class OperatorBase:
assert isinstance(k, DispatchKey) assert isinstance(k, DispatchKey)
assert ( assert (
k != DispatchKey.Python k != DispatchKey.Python
), "Please register a mode for the torch._C.DispatchKey.Python key instead." ), "Please register a mode for the DispatchKey.Python key instead."
if k in self.py_kernels: if k in self.py_kernels:
raise RuntimeError( raise RuntimeError(
@ -157,31 +169,34 @@ class OperatorBase:
# with ctx.redispatch_to_next(): # with ctx.redispatch_to_next():
# out = ctx.functionalize(inner_f)(*args_unwrapped) # out = ctx.functionalize(inner_f)(*args_unwrapped)
# return ctx.wrap_tensors(out) # return ctx.wrap_tensors(out)
def py_functionalize_impl(self, fn: _F) -> _F: def py_functionalize_impl(
self, fn: Callable[Concatenate["BaseFunctionalizeAPI", _P], _T]
) -> Callable[Concatenate["BaseFunctionalizeAPI", _P], _T]:
from torch._subclasses.functional_tensor import ( from torch._subclasses.functional_tensor import (
CppFunctionalizeAPI as _CppFunctionalizeAPI, CppFunctionalizeAPI,
FunctorchFunctionalizeAPI as _FunctorchFunctionalizeAPI, FunctionalTensorMode,
PythonFunctionalizeAPI as _PythonFunctionalizeAPI, FunctorchFunctionalizeAPI,
PythonFunctionalizeAPI,
) )
# Construct our three flavors of functionalization, # Construct our three flavors of functionalization,
# each of which have slightly different wrap/unwrap/redispatch policies # each of which have slightly different wrap/unwrap/redispatch policies
def functionalize_dk_fn(*args, **kwargs): def functionalize_dk_fn(*args: _P.args, **kwargs: _P.kwargs) -> _T:
return fn(_CppFunctionalizeAPI(), *args, **kwargs) return fn(CppFunctionalizeAPI(), *args, **kwargs)
def functionalize_dispatch_mode_fn(mode, *args, **kwargs): def functionalize_dispatch_mode_fn(
return fn(_PythonFunctionalizeAPI(mode), *args, **kwargs) mode: Optional[FunctionalTensorMode], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
return fn(PythonFunctionalizeAPI(mode), *args, **kwargs)
def functionalize_functorch_fn(interpreter, *args, **kwargs): def functionalize_functorch_fn(
return fn(_FunctorchFunctionalizeAPI(interpreter), *args, **kwargs) interpreter, *args: _P.args, **kwargs: _P.kwargs
) -> _T:
return fn(FunctorchFunctionalizeAPI(interpreter), *args, **kwargs)
self.py_impl(DispatchKey.Functionalize)(functionalize_dk_fn) self.py_impl(DispatchKey.Functionalize)(functionalize_dk_fn)
self.py_impl(torch._subclasses.functional_tensor.FunctionalTensorMode)( self.py_impl(FunctionalTensorMode)(functionalize_dispatch_mode_fn)
functionalize_dispatch_mode_fn self.py_impl(TransformType.Functionalize)(functionalize_functorch_fn)
)
self.py_impl(torch._C._functorch.TransformType.Functionalize)(
functionalize_functorch_fn
)
return fn return fn
@ -294,7 +309,7 @@ class HigherOrderOperator(OperatorBase, abc.ABC):
k: Union[ k: Union[
Type[TorchDispatchMode], Type[TorchDispatchMode],
Type[torch.Tensor], Type[torch.Tensor],
torch._C._functorch.TransformType, TransformType,
DispatchKey, DispatchKey,
], ],
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
@ -408,7 +423,7 @@ class HigherOrderOperator(OperatorBase, abc.ABC):
curr_mode = _get_current_dispatch_mode_pre_dispatch() curr_mode = _get_current_dispatch_mode_pre_dispatch()
assert ( assert (
curr_mode is not None curr_mode is not None
), "Illegal invocation of dispatch on torch._C.DispatchKey.PreDispatch without a mode." ), "Illegal invocation of dispatch on DispatchKey.PreDispatch without a mode."
assert ( assert (
type(curr_mode) in self.python_key_table type(curr_mode) in self.python_key_table
), f"Current active mode {curr_mode} not registered" ), f"Current active mode {curr_mode} not registered"
@ -817,7 +832,7 @@ class OpOverload(OperatorBase):
curr_mode = type(_get_current_dispatch_mode()) curr_mode = type(_get_current_dispatch_mode())
assert ( assert (
curr_mode is not None curr_mode is not None
), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode." ), "Illegal invocation of dispatch on DispatchKey.Python without a mode."
if curr_mode not in self.python_key_table: if curr_mode not in self.python_key_table:
if isinstance(self, TorchBindOpOverload): if isinstance(self, TorchBindOpOverload):

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
from typing import Any, Callable, List, TypeVar from typing import Any, Callable, List, TypeVar
from typing_extensions import ParamSpec
import torch import torch
@ -21,7 +22,8 @@ __all__ = [
] ]
_F = TypeVar("_F", bound=Callable[..., Any]) _P = ParamSpec("_P")
_R = TypeVar("_R")
def compile(*args, **kwargs): def compile(*args, **kwargs):
@ -124,11 +126,11 @@ def allow_in_graph(fn):
def substitute_in_graph( def substitute_in_graph(
original_fn: _F, original_fn: Callable[_P, _R],
*, *,
can_constant_fold_through: bool = False, can_constant_fold_through: bool = False,
skip_signature_check: bool = False, skip_signature_check: bool = False,
) -> Callable[[_F], _F]: ) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
""" """
Register a polyfill handler for a function, usually a C function from the C extension, to be Register a polyfill handler for a function, usually a C function from the C extension, to be
used in place of the original function when inlining the original function in the graph. used in place of the original function when inlining the original function in the graph.