mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Use pyupgrade(https://github.com/asottile/pyupgrade) and flynt to modernize python syntax ```sh pyupgrade --py36-plus --keep-runtime-typing torch/onnx/**/*.py pyupgrade --py36-plus --keep-runtime-typing test/onnx/**/*.py flynt torch/onnx/ --line-length 120 ``` - Use f-strings for string formatting - Use the new `super()` syntax for class initialization - Use dictionary / set comprehension Pull Request resolved: https://github.com/pytorch/pytorch/pull/77935 Approved by: https://github.com/BowenBao
174 lines
6.7 KiB
Python
174 lines
6.7 KiB
Python
import importlib
|
|
import inspect
|
|
import itertools
|
|
import warnings
|
|
from typing import Any, Callable, Dict, Tuple, Union
|
|
|
|
from torch import _C
|
|
from torch.onnx import _constants
|
|
|
|
_SymbolicFunction = Callable[..., Union[_C.Value, Tuple[_C.Value]]]
|
|
|
|
"""
|
|
The symbolic registry "_registry" is a dictionary that maps operators
|
|
(for a specific domain and opset version) to their symbolic functions.
|
|
An operator is defined by its domain, opset version, and opname.
|
|
The keys are tuples (domain, version), (where domain is a string, and version is an int),
|
|
and the operator's name (string).
|
|
The map's entries are as follows : _registry[(domain, version)][op_name] = op_symbolic
|
|
"""
|
|
_registry: Dict[
|
|
Tuple[str, int],
|
|
Dict[str, _SymbolicFunction],
|
|
] = {}
|
|
|
|
_symbolic_versions: Dict[Union[int, str], Any] = {}
|
|
|
|
|
|
def _import_symbolic_opsets():
|
|
for opset_version in itertools.chain(
|
|
_constants.onnx_stable_opsets, [_constants.onnx_main_opset]
|
|
):
|
|
module = importlib.import_module(f"torch.onnx.symbolic_opset{opset_version}")
|
|
global _symbolic_versions
|
|
_symbolic_versions[opset_version] = module
|
|
|
|
|
|
def register_version(domain: str, version: int):
|
|
if not is_registered_version(domain, version):
|
|
global _registry
|
|
_registry[(domain, version)] = {}
|
|
register_ops_in_version(domain, version)
|
|
|
|
|
|
def register_ops_helper(domain: str, version: int, iter_version: int):
|
|
for domain, op_name, op_func in get_ops_in_version(iter_version):
|
|
if not is_registered_op(op_name, domain, version):
|
|
register_op(op_name, op_func, domain, version)
|
|
|
|
|
|
def register_ops_in_version(domain: str, version: int):
|
|
# iterates through the symbolic functions of
|
|
# the specified opset version, and the previous
|
|
# opset versions for operators supported in
|
|
# previous versions.
|
|
|
|
# Opset 9 is the base version. It is selected as the base version because
|
|
# 1. It is the first opset version supported by PyTorch export.
|
|
# 2. opset 9 is more robust than previous opset versions. Opset versions like 7/8 have limitations
|
|
# that certain basic operators cannot be expressed in ONNX. Instead of basing on these limitations,
|
|
# we chose to handle them as special cases separately.
|
|
# Backward support for opset versions beyond opset 7 is not in our roadmap.
|
|
|
|
# For opset versions other than 9, by default they will inherit the symbolic functions defined in
|
|
# symbolic_opset9.py.
|
|
# To extend support for updated operators in different opset versions on top of opset 9,
|
|
# simply add the updated symbolic functions in the respective symbolic_opset{version}.py file.
|
|
# Checkout topk in symbolic_opset10.py, and upsample_nearest2d in symbolic_opset8.py for example.
|
|
iter_version = version
|
|
while iter_version != 9:
|
|
register_ops_helper(domain, version, iter_version)
|
|
if iter_version > 9:
|
|
iter_version = iter_version - 1
|
|
else:
|
|
iter_version = iter_version + 1
|
|
|
|
register_ops_helper(domain, version, 9)
|
|
|
|
|
|
def get_ops_in_version(version: int):
|
|
if not _symbolic_versions:
|
|
_import_symbolic_opsets()
|
|
members = inspect.getmembers(_symbolic_versions[version])
|
|
domain_opname_ops = []
|
|
for obj in members:
|
|
if isinstance(obj[1], type) and hasattr(obj[1], "domain"):
|
|
ops = inspect.getmembers(obj[1], predicate=inspect.isfunction)
|
|
for op in ops:
|
|
domain_opname_ops.append((obj[1].domain, op[0], op[1])) # type: ignore[attr-defined]
|
|
|
|
elif inspect.isfunction(obj[1]):
|
|
if obj[0] == "_len":
|
|
obj = ("len", obj[1])
|
|
if obj[0] == "_list":
|
|
obj = ("list", obj[1])
|
|
if obj[0] == "_any":
|
|
obj = ("any", obj[1])
|
|
if obj[0] == "_all":
|
|
obj = ("all", obj[1])
|
|
domain_opname_ops.append(("", obj[0], obj[1]))
|
|
return domain_opname_ops
|
|
|
|
|
|
def is_registered_version(domain: str, version: int):
|
|
global _registry
|
|
return (domain, version) in _registry
|
|
|
|
|
|
def register_op(opname, op, domain, version):
|
|
if domain is None or version is None:
|
|
warnings.warn(
|
|
"ONNX export failed. The ONNX domain and/or version to register are None."
|
|
)
|
|
global _registry
|
|
if not is_registered_version(domain, version):
|
|
_registry[(domain, version)] = {}
|
|
_registry[(domain, version)][opname] = op
|
|
|
|
|
|
def is_registered_op(opname: str, domain: str, version: int):
|
|
if domain is None or version is None:
|
|
warnings.warn("ONNX export failed. The ONNX domain and/or version are None.")
|
|
global _registry
|
|
return (domain, version) in _registry and opname in _registry[(domain, version)]
|
|
|
|
|
|
def unregister_op(opname: str, domain: str, version: int):
|
|
global _registry
|
|
if is_registered_op(opname, domain, version):
|
|
del _registry[(domain, version)][opname]
|
|
if not _registry[(domain, version)]:
|
|
del _registry[(domain, version)]
|
|
else:
|
|
warnings.warn("The opname " + opname + " is not registered.")
|
|
|
|
|
|
def get_op_supported_version(opname: str, domain: str, version: int):
|
|
iter_version = version
|
|
while iter_version <= _constants.onnx_main_opset:
|
|
ops = [(op[0], op[1]) for op in get_ops_in_version(iter_version)]
|
|
if (domain, opname) in ops:
|
|
return iter_version
|
|
iter_version += 1
|
|
return None
|
|
|
|
|
|
def get_registered_op(opname: str, domain: str, version: int) -> _SymbolicFunction:
|
|
if domain is None or version is None:
|
|
warnings.warn("ONNX export failed. The ONNX domain and/or version are None.")
|
|
global _registry
|
|
if not is_registered_op(opname, domain, version):
|
|
raise UnsupportedOperatorError(domain, opname, version)
|
|
return _registry[(domain, version)][opname]
|
|
|
|
|
|
class UnsupportedOperatorError(RuntimeError):
|
|
def __init__(self, domain: str, opname: str, version: int):
|
|
supported_version = get_op_supported_version(opname, domain, version)
|
|
if domain in {"", "aten", "prim", "quantized"}:
|
|
msg = f"Exporting the operator {domain}::{opname} to ONNX opset version {version} is not supported. "
|
|
if supported_version is not None:
|
|
msg += (
|
|
f"Support for this operator was added in version {supported_version}, "
|
|
"try exporting with this version."
|
|
)
|
|
else:
|
|
msg += "Please feel free to request support or submit a pull request on PyTorch GitHub."
|
|
else:
|
|
msg = (
|
|
f"ONNX export failed on an operator with unrecognized namespace {domain}::{opname}. "
|
|
"If you are trying to export a custom operator, make sure you registered "
|
|
"it with the right domain and version."
|
|
)
|
|
super().__init__(msg)
|