mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[BE] typing for decorators - onnx/symbolic_helper (#131565)
See #131429 Pull Request resolved: https://github.com/pytorch/pytorch/pull/131565 Approved by: https://github.com/justinchuby, https://github.com/oulgen, https://github.com/zou3519, https://github.com/titaiwangms
This commit is contained in:
parent
0e71a88f9b
commit
abcd329359
|
|
@ -1,4 +1,3 @@
|
|||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
# mypy: disable-error-code=arg-type
|
||||
import importlib
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
from __future__ import annotations
|
||||
|
||||
|
|
@ -8,7 +7,8 @@ import math
|
|||
import sys
|
||||
import typing
|
||||
import warnings
|
||||
from typing import Any, Callable, Literal, NoReturn, Sequence
|
||||
from typing import Any, Callable, Literal, NoReturn, Sequence, TypeVar
|
||||
from typing_extensions import Concatenate, ParamSpec
|
||||
|
||||
import torch
|
||||
import torch._C._onnx as _C_onnx
|
||||
|
|
@ -22,6 +22,9 @@ from torch.onnx._internal import jit_utils
|
|||
if typing.TYPE_CHECKING:
|
||||
from torch.types import Number
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_U = TypeVar("_U")
|
||||
_P = ParamSpec("_P")
|
||||
|
||||
# ---------------------------------------------------------------------------------
|
||||
# Helper functions
|
||||
|
|
@ -199,7 +202,9 @@ def _is_packed_list(list_value: Any) -> bool:
|
|||
return _is_value(list_value) and list_value.node().kind() == "prim::ListConstruct"
|
||||
|
||||
|
||||
def parse_args(*arg_descriptors: _ValueDescriptor):
|
||||
def parse_args(
|
||||
*arg_descriptors: _ValueDescriptor,
|
||||
) -> Callable[[Callable[Concatenate[_U, _P], _T]], Callable[Concatenate[_U, _P], _T]]:
|
||||
"""A decorator which converts args from torch._C.Value to built-in types.
|
||||
|
||||
For example:
|
||||
|
|
@ -227,11 +232,13 @@ def parse_args(*arg_descriptors: _ValueDescriptor):
|
|||
"none": the variable is unused
|
||||
"""
|
||||
|
||||
def decorator(fn):
|
||||
fn._arg_descriptors = arg_descriptors
|
||||
def decorator(
|
||||
fn: Callable[Concatenate[_U, _P], _T]
|
||||
) -> Callable[Concatenate[_U, _P], _T]:
|
||||
fn._arg_descriptors = arg_descriptors # type: ignore[attr-defined]
|
||||
|
||||
@functools.wraps(fn)
|
||||
def wrapper(g, *args, **kwargs):
|
||||
def wrapper(g: _U, *args: _P.args, **kwargs: _P.kwargs) -> _T:
|
||||
# some args may be optional, so the length may be smaller
|
||||
FILE_BUG_MSG = (
|
||||
"If you believe this is not due to custom symbolic implementation within your code or "
|
||||
|
|
@ -282,7 +289,7 @@ def quantized_args(
|
|||
scale: float | None = None,
|
||||
zero_point: int | None = None,
|
||||
quantize_output: bool = True,
|
||||
):
|
||||
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
|
||||
"""A decorator which extends support for quantized version of the base operator.
|
||||
|
||||
Quantization is detected by examining the arguments that are annotated by
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
# mypy: disable-error-code=arg-type
|
||||
from __future__ import annotations
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
# mypy: disable-error-code=arg-type
|
||||
"""This file exports ONNX ops for opset 11."""
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
# mypy: disable-error-code=arg-type
|
||||
from __future__ import annotations
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
# EDITING THIS FILE? READ THIS FIRST!
|
||||
# see Note [Edit Symbolic Files] in README.md
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
# mypy: disable-error-code=arg-type
|
||||
"""This file exports ONNX ops for opset 14.
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
"""This file exports ONNX ops for opset 16.
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
# mypy: disable-error-code=arg-type
|
||||
"""This file exports ONNX ops for opset 17.
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
"""This file exports ONNX ops for opset 18.
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
"""This file exports ONNX ops for opset 20.
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
"""
|
||||
Note [ONNX operators that are added/updated from opset 8 to opset 9]
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user