[BE] typing for decorators - onnx/symbolic_helper (#131565)

See #131429
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131565
Approved by: https://github.com/justinchuby, https://github.com/oulgen, https://github.com/zou3519, https://github.com/titaiwangms
This commit is contained in:
Aaron Orenstein 2024-07-24 07:57:43 -07:00 committed by PyTorch MergeBot
parent 0e71a88f9b
commit abcd329359
12 changed files with 14 additions and 18 deletions

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
# mypy: disable-error-code=arg-type
import importlib

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from __future__ import annotations
@ -8,7 +7,8 @@ import math
import sys
import typing
import warnings
from typing import Any, Callable, Literal, NoReturn, Sequence
from typing import Any, Callable, Literal, NoReturn, Sequence, TypeVar
from typing_extensions import Concatenate, ParamSpec
import torch
import torch._C._onnx as _C_onnx
@ -22,6 +22,9 @@ from torch.onnx._internal import jit_utils
if typing.TYPE_CHECKING:
from torch.types import Number
_T = TypeVar("_T")
_U = TypeVar("_U")
_P = ParamSpec("_P")
# ---------------------------------------------------------------------------------
# Helper functions
@ -199,7 +202,9 @@ def _is_packed_list(list_value: Any) -> bool:
return _is_value(list_value) and list_value.node().kind() == "prim::ListConstruct"
def parse_args(*arg_descriptors: _ValueDescriptor):
def parse_args(
*arg_descriptors: _ValueDescriptor,
) -> Callable[[Callable[Concatenate[_U, _P], _T]], Callable[Concatenate[_U, _P], _T]]:
"""A decorator which converts args from torch._C.Value to built-in types.
For example:
@ -227,11 +232,13 @@ def parse_args(*arg_descriptors: _ValueDescriptor):
"none": the variable is unused
"""
def decorator(fn):
fn._arg_descriptors = arg_descriptors
def decorator(
fn: Callable[Concatenate[_U, _P], _T]
) -> Callable[Concatenate[_U, _P], _T]:
fn._arg_descriptors = arg_descriptors # type: ignore[attr-defined]
@functools.wraps(fn)
def wrapper(g, *args, **kwargs):
def wrapper(g: _U, *args: _P.args, **kwargs: _P.kwargs) -> _T:
# some args may be optional, so the length may be smaller
FILE_BUG_MSG = (
"If you believe this is not due to custom symbolic implementation within your code or "
@ -282,7 +289,7 @@ def quantized_args(
scale: float | None = None,
zero_point: int | None = None,
quantize_output: bool = True,
):
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
"""A decorator which extends support for quantized version of the base operator.
Quantization is detected by examining the arguments that are annotated by

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
# mypy: disable-error-code=arg-type
from __future__ import annotations

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
# mypy: disable-error-code=arg-type
"""This file exports ONNX ops for opset 11."""

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
# mypy: disable-error-code=arg-type
from __future__ import annotations

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
# mypy: disable-error-code=arg-type
"""This file exports ONNX ops for opset 14.

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
"""This file exports ONNX ops for opset 16.

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
# mypy: disable-error-code=arg-type
"""This file exports ONNX ops for opset 17.

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
"""This file exports ONNX ops for opset 18.

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
"""This file exports ONNX ops for opset 20.

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
"""
Note [ONNX operators that are added/updated from opset 8 to opset 9]