Improve typing of ONNX decorators with ParamSpec (#162332)
## Summary

This PR improves typing in ONNX-related modules by replacing a TypeVar bound to `Callable[..., Any]` with ParamSpec, preserving parameter types and avoiding type erasure in decorator functions.

## Changes

- `torch/onnx/_internal/exporter/_flags.py`: Replace the `TCallable` TypeVar with ParamSpec
- `torch/onnx/ops/_impl.py`: Replace the `_T` TypeVar with ParamSpec for the `_onnx_op` decorator
- `torch/onnx/_internal/exporter/_torchlib/_torchlib_registry.py`: Replace the `_T` TypeVar with ParamSpec

## Motivation

The previous implementation used a TypeVar bound to `Callable`, which erased parameter type information to `Any` inside the decorators. ParamSpec preserves the exact parameter and return types, providing better type safety and IDE support.

## Testing

- Verified all changes compile and import correctly
- Created a comprehensive test suite to validate ParamSpec functionality
- No linting errors introduced
- Maintains backward compatibility

Fixes #142306

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162332
Approved by: https://github.com/Skylion007
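To make the Motivation concrete, here is a minimal before/after sketch. The toy decorators `old_style` and `new_style` and the `scale` function are illustrative only, not code from this PR; they follow the same two patterns the diff swaps between. The old pattern leaves the wrapper body unchecked and needs a `cast`; the ParamSpec pattern ties the wrapper to the decorated function's real signature:

```python
import functools
from typing import Any, Callable, TypeVar, cast

from typing_extensions import ParamSpec

TCallable = TypeVar("TCallable", bound=Callable[..., Any])
_P = ParamSpec("_P")
_R = TypeVar("_R")


def old_style(func: TCallable) -> TCallable:
    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # Everything in here is Any: the call to func is unchecked.
        return func(*args, **kwargs)

    # wrapper's inferred type does not match TCallable, so a cast is needed.
    return cast(TCallable, wrapper)


def new_style(func: Callable[_P, _R]) -> Callable[_P, _R]:
    @functools.wraps(func)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        # The forwarding call type-checks against func's real signature.
        return func(*args, **kwargs)

    return wrapper  # already Callable[_P, _R]; no cast required


@new_style
def scale(x: float, factor: float = 2.0) -> float:
    return x * factor


print(scale(3.0))  # 6.0; checkers and IDEs see (x: float, factor: float = ...) -> float
```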
This commit is contained in:
parent 7a83cf430e
commit 9ad5e8edb1
torch/onnx/_internal/exporter/_flags.py:

```diff
@@ -3,17 +3,20 @@
 from __future__ import annotations
 
 import functools
-from typing import Any, Callable, cast, TypeVar
+from typing import Callable, TypeVar
+from typing_extensions import ParamSpec
 
 
 _is_onnx_exporting = False
 
-TCallable = TypeVar("TCallable", bound=Callable[..., Any])
+# Use ParamSpec to preserve parameter types instead of erasing to Any
+_P = ParamSpec("_P")
+_R = TypeVar("_R")
 
 
-def set_onnx_exporting_flag(func: TCallable) -> TCallable:
+def set_onnx_exporting_flag(func: Callable[_P, _R]) -> Callable[_P, _R]:
     @functools.wraps(func)
-    def wrapper(*args: Any, **kwargs: Any) -> Any:
+    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
         global _is_onnx_exporting
         _is_onnx_exporting = True
         try:
@@ -22,4 +25,4 @@ def set_onnx_exporting_flag(func: TCallable) -> TCallable:
             # Ensure it resets even if an exception occurs
             _is_onnx_exporting = False
 
-    return cast(TCallable, wrapper)
+    return wrapper
```
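Read together, the two hunks give nearly the whole updated module. For clarity, here is a runnable reconstruction; the lines the diff elides between the hunks (the `return func(...)` inside a `try`/`finally`) are filled in as an assumption from the surrounding context:

```python
from __future__ import annotations

import functools
from typing import Callable, TypeVar

from typing_extensions import ParamSpec

_is_onnx_exporting = False

_P = ParamSpec("_P")
_R = TypeVar("_R")


def set_onnx_exporting_flag(func: Callable[_P, _R]) -> Callable[_P, _R]:
    @functools.wraps(func)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        global _is_onnx_exporting
        _is_onnx_exporting = True
        try:
            # Assumed elided body: forward to the wrapped function.
            return func(*args, **kwargs)
        finally:
            # Ensure it resets even if an exception occurs
            _is_onnx_exporting = False

    return wrapper
```

Because `wrapper` is annotated with `_P.args`/`_P.kwargs`, the checker verifies the forwarding call against `func`'s real signature, which is exactly what made the old `cast(TCallable, wrapper)` unnecessary.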
torch/onnx/_internal/exporter/_torchlib/_torchlib_registry.py:

```diff
@@ -10,6 +10,7 @@ __all__ = ["onnx_impl", "get_torchlib_ops"]
 import logging
 from collections.abc import Sequence
 from typing import Any, Callable, TypeVar
+from typing_extensions import ParamSpec
 
 import onnxscript
 
@@ -17,7 +18,9 @@ import torch
 from torch.onnx._internal.exporter import _constants, _registration
 
 
-_T = TypeVar("_T", bound=Callable)
+# Use ParamSpec for better type preservation instead of bound Callable TypeVar
+_P = ParamSpec("_P")
+_R = TypeVar("_R")
 
 logger = logging.getLogger("__name__")
 
@@ -33,7 +36,7 @@ def onnx_impl(
     opset_introduced: int = 18,
     no_compile: bool = False,
     private: bool = False,
-) -> Callable[[_T], _T]:
+) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
     """Register an ONNX implementation of a torch op."""
 
     if isinstance(target, torch._ops.OpOverloadPacket):
@@ -44,8 +47,8 @@ def onnx_impl(
         )
 
     def wrapper(
-        func: _T,
-    ) -> _T:
+        func: Callable[_P, _R],
+    ) -> Callable[_P, _R]:
         processed_func: Any
         if no_compile:
             processed_func = func
```
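`onnx_impl` is the decorator-factory form of the same idea: the factory takes configuration and returns a decorator of type `Callable[[Callable[_P, _R]], Callable[_P, _R]]`, so whatever it registers keeps its exact signature. A minimal sketch of that shape; `_REGISTRY`, `register`, and `relu` are hypothetical stand-ins, not PyTorch's real torchlib registry:

```python
from typing import Callable, TypeVar

from typing_extensions import ParamSpec

_P = ParamSpec("_P")
_R = TypeVar("_R")

# Hypothetical registry; the real onnx_impl registers into PyTorch's own.
_REGISTRY: dict[str, Callable] = {}


def register(name: str) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
    """Toy factory mirroring onnx_impl's shape: config in, decorator out."""

    def decorator(func: Callable[_P, _R]) -> Callable[_P, _R]:
        _REGISTRY[name] = func
        return func  # returned unchanged, so its exact signature survives

    return decorator


@register("onnx::Relu")
def relu(x: float) -> float:
    return max(x, 0.0)


print(relu(-1.5))       # 0.0; checker still sees (x: float) -> float
print(list(_REGISTRY))  # ['onnx::Relu']
```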
torch/onnx/ops/_impl.py:

```diff
@@ -1,13 +1,15 @@
 # flake8: noqa: B950
 import math
 import typing
-from typing import Callable, Optional
+from typing import Callable, Optional, TypeVar
+from typing_extensions import ParamSpec
 
 import torch
 from torch.onnx.ops import _dtype_mappings
 
 
-_T = typing.TypeVar("_T", bound=Callable)
+# Use ParamSpec for better type preservation instead of bound Callable TypeVar
+_P = ParamSpec("_P")
+_R = TypeVar("_R")
 
 # ONNX to ATen decomp table
 ONNX_ATEN_DECOMP_TABLE: dict[torch._ops.OpOverload, Callable] = {}
@@ -21,10 +23,12 @@ _ATTENTION_23_ALLOWED_INTERMEDIATE_PRECISIONS = frozenset(
 )
 
 
-def _onnx_op(op_type: str, opset_version: int) -> Callable[[_T], _T]:
+def _onnx_op(
+    op_type: str, opset_version: int
+) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
     """Decorator to register an ONNX operator with a custom implementation."""
 
-    def decorator(func: _T) -> _T:
+    def decorator(func: Callable[_P, _R]) -> Callable[_P, _R]:
         overload = f"opset{opset_version}"
         torch_op = torch.library.custom_op(
             f"onnx::{op_type}.{overload}", mutates_args=()
```
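One detail worth noting: `ParamSpec` only landed in the standard `typing` module in Python 3.10 (PEP 612), which is presumably why all three files import it from `typing_extensions`, the backport package. A version-guarded import would also work on newer interpreters; this is an illustrative alternative, not what the PR does:

```python
import sys

# ParamSpec is in typing from Python 3.10 (PEP 612);
# typing_extensions backports it to older interpreters.
if sys.version_info >= (3, 10):
    from typing import ParamSpec
else:
    from typing_extensions import ParamSpec
```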