Improve typing of ONNX decorators with ParamSpec (#162332)

## Summary
This PR improves typing in ONNX-related modules by replacing TypeVar bound to Callable[..., Any] with ParamSpec to preserve parameter types and avoid type erasure in decorator functions.

## Changes
- `torch/onnx/_internal/exporter/_flags.py`: Replace TCallable TypeVar with ParamSpec
- `torch/onnx/ops/_impl.py`: Replace _T TypeVar with ParamSpec for _onnx_op decorator
- `torch/onnx/_internal/exporter/_torchlib/_torchlib_registry.py`: Replace _T TypeVar with ParamSpec

## Motivation
The previous implementation used TypeVar bound to Callable which erased parameter type information to Any. ParamSpec preserves the exact parameter types and return types, providing better type safety and IDE support.

## Testing
- Verified all changes compile and import correctly
- Created comprehensive test suite to validate ParamSpec functionality
- No linting errors introduced
- Maintains backward compatibility

Fixes #142306
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162332
Approved by: https://github.com/Skylion007
This commit is contained in:
Vinayak Pawar 2025-09-07 18:06:03 +00:00 committed by PyTorch MergeBot
parent 7a83cf430e
commit 9ad5e8edb1
3 changed files with 24 additions and 14 deletions

View File

@ -3,17 +3,20 @@
from __future__ import annotations
import functools
from typing import Any, Callable, cast, TypeVar
from typing import Callable, TypeVar
from typing_extensions import ParamSpec
_is_onnx_exporting = False
TCallable = TypeVar("TCallable", bound=Callable[..., Any])
# Use ParamSpec to preserve parameter types instead of erasing to Any
_P = ParamSpec("_P")
_R = TypeVar("_R")
def set_onnx_exporting_flag(func: TCallable) -> TCallable:
def set_onnx_exporting_flag(func: Callable[_P, _R]) -> Callable[_P, _R]:
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
global _is_onnx_exporting
_is_onnx_exporting = True
try:
@ -22,4 +25,4 @@ def set_onnx_exporting_flag(func: TCallable) -> TCallable:
# Ensure it resets even if an exception occurs
_is_onnx_exporting = False
return cast(TCallable, wrapper)
return wrapper

View File

@ -10,6 +10,7 @@ __all__ = ["onnx_impl", "get_torchlib_ops"]
import logging
from collections.abc import Sequence
from typing import Any, Callable, TypeVar
from typing_extensions import ParamSpec
import onnxscript
@ -17,7 +18,9 @@ import torch
from torch.onnx._internal.exporter import _constants, _registration
_T = TypeVar("_T", bound=Callable)
# Use ParamSpec for better type preservation instead of bound Callable TypeVar
_P = ParamSpec("_P")
_R = TypeVar("_R")
logger = logging.getLogger("__name__")
@ -33,7 +36,7 @@ def onnx_impl(
opset_introduced: int = 18,
no_compile: bool = False,
private: bool = False,
) -> Callable[[_T], _T]:
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
"""Register an ONNX implementation of a torch op."""
if isinstance(target, torch._ops.OpOverloadPacket):
@ -44,8 +47,8 @@ def onnx_impl(
)
def wrapper(
func: _T,
) -> _T:
func: Callable[_P, _R],
) -> Callable[_P, _R]:
processed_func: Any
if no_compile:
processed_func = func

View File

@ -1,13 +1,15 @@
# flake8: noqa: B950
import math
import typing
from typing import Callable, Optional
from typing import Callable, Optional, TypeVar
from typing_extensions import ParamSpec
import torch
from torch.onnx.ops import _dtype_mappings
_T = typing.TypeVar("_T", bound=Callable)
# Use ParamSpec for better type preservation instead of bound Callable TypeVar
_P = ParamSpec("_P")
_R = TypeVar("_R")
# ONNX to ATen decomp table
ONNX_ATEN_DECOMP_TABLE: dict[torch._ops.OpOverload, Callable] = {}
@ -21,10 +23,12 @@ _ATTENTION_23_ALLOWED_INTERMEDIATE_PRECISIONS = frozenset(
)
def _onnx_op(op_type: str, opset_version: int) -> Callable[[_T], _T]:
def _onnx_op(
op_type: str, opset_version: int
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
"""Decorator to register an ONNX operator with a custom implementation."""
def decorator(func: _T) -> _T:
def decorator(func: Callable[_P, _R]) -> Callable[_P, _R]:
overload = f"opset{opset_version}"
torch_op = torch.library.custom_op(
f"onnx::{op_type}.{overload}", mutates_args=()