mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: This is an extension to the original PR https://github.com/pytorch/pytorch/pull/21765 1. Increase the coverage of different opsets support, comments, and blacklisting. 2. Adding backend tests for both caffe2 and onnxruntime on opset 7 and opset 8. 3. Reusing onnx model tests in caffe2 for onnxruntime. Pull Request resolved: https://github.com/pytorch/pytorch/pull/22421 Reviewed By: zrphercule Differential Revision: D16225518 Pulled By: houseroad fbshipit-source-id: 01ae3eed85111a83a0124e9e95512b80109d6aee
92 lines
3.7 KiB
Python
92 lines
3.7 KiB
Python
import warnings
import importlib
from inspect import getmembers, isfunction

# The symbolic registry "_registry" is a dictionary that maps operators
# (for a specific domain and opset version) to their symbolic functions.
# An operator is identified by its domain, opset version, and op name.
# The outer keys are (domain, version) tuples, where domain is a string and
# version is an int; the inner keys are operator names (strings).
# The map's entries are as follows: _registry[(domain, version)][op_name] = op_symbolic
_registry = {}

# Maps each stable opset version to its imported torch.onnx.symbolic_opset{version}
# module. The dict is bound before the symbolic modules are imported (as it always
# was), and it is filled through a generator expression. As a result, no loop
# temporaries (previously "opset_version" and "module") leak into this module's
# namespace.
_symbolic_versions = {}

from torch.onnx.symbolic_helper import _onnx_stable_opsets

_symbolic_versions.update(
    (opset_version, importlib.import_module('torch.onnx.symbolic_opset{}'.format(opset_version)))
    for opset_version in _onnx_stable_opsets
)


def register_version(domain, version):
    """Make sure (domain, version) has a registry table, then populate it.

    An empty op table is created for the (domain, version) key only if none
    exists yet. The symbolic functions visible to that opset version are then
    registered through register_ops_in_version. Ops that already have an entry
    are skipped there, so repeated calls do not overwrite anything.
    """
    # Item assignment on the module-level dict needs no ``global`` statement.
    _registry.setdefault((domain, version), {})
    register_ops_in_version(domain, version)


def register_ops_helper(domain, version, iter_version):
    """Register the function members of opset ``iter_version``'s symbolic module.

    Each (name, member) pair of the module is registered under (domain, version).
    Members that are not plain Python functions are skipped. Names that already
    have a symbolic registered for (domain, version) are left as they are, so
    the first module visited for a given name wins.
    """
    for name, member in get_ops_in_version(iter_version):
        if not isfunction(member):
            continue
        if is_registered_op(name, domain, version):
            continue
        register_op(name, member, domain, version)


def register_ops_in_version(domain, version):
    """Populate the (domain, version) table from the symbolic opset modules.

    The symbolic module for ``version`` is visited first. Each intermediate
    opset version is then visited, one step at a time toward opset 9, and
    opset 9 itself comes last. register_ops_helper never overwrites an
    existing entry, so a symbolic defined in a module closer to ``version``
    takes precedence over the same op defined closer to opset 9.
    """
    # Opset 9 is the base version. It was chosen because:
    #   1. It is the first opset version supported by PyTorch export.
    #   2. It is more robust than earlier opsets. Opsets like 7 and 8 have
    #      limitations that keep certain basic operators from being expressed
    #      in ONNX. Rather than building on those limitations, they are handled
    #      as special cases separately.
    # Supporting opset versions earlier than opset 7 is not on our roadmap.
    #
    # By default, every opset version other than 9 inherits the symbolic
    # functions defined in symbolic_opset9.py. To support an operator that was
    # updated in a different opset version, add the updated symbolic function
    # to the corresponding symbolic_opset{version}.py file. See topk in
    # symbolic_opset10.py and upsample_nearest2d in symbolic_opset8.py for
    # examples.
    base_version = 9
    current = version
    while current != base_version:
        register_ops_helper(domain, version, current)
        # Step one opset toward the base version (down from above, up from below).
        current = current - 1 if current > base_version else current + 1

    register_ops_helper(domain, version, base_version)


def get_ops_in_version(version):
    """Return ``getmembers`` of the symbolic module imported for ``version``.

    The result is a list of (name, value) pairs sorted by name. A version with
    no entry in _symbolic_versions raises KeyError.
    """
    symbolic_module = _symbolic_versions[version]
    return getmembers(symbolic_module)


def is_registered_version(domain, version):
    """Return True if the registry already holds a table for (domain, version)."""
    # Reading a module-level name needs no ``global`` statement.
    key = (domain, version)
    return key in _registry


def register_op(opname, op, domain, version):
    """Register ``op`` as the symbolic function for ``opname`` in (domain, version).

    The (domain, version) table is created on demand. Any symbolic already
    registered under ``opname`` in that table is replaced.

    If ``domain`` or ``version`` is None, a warning is emitted but registration
    still proceeds under that key. The warning uses ``stacklevel=2`` so it
    points at the caller that passed the bad arguments, not at this module.
    """
    if domain is None or version is None:
        warnings.warn("ONNX export failed. The ONNX domain and/or version to register are None.",
                      stacklevel=2)
    # Item assignment on the module-level dict needs no ``global`` statement.
    _registry.setdefault((domain, version), {})[opname] = op


def is_registered_op(opname, domain, version):
    """Return True if a symbolic is registered for ``opname`` in (domain, version).

    If ``domain`` or ``version`` is None, a warning is emitted and the lookup
    still proceeds. The warning uses ``stacklevel=2`` so it points at the
    caller rather than at this module.
    """
    if domain is None or version is None:
        warnings.warn("ONNX export failed. The ONNX domain and/or version are None.", stacklevel=2)
    key = (domain, version)
    return key in _registry and opname in _registry[key]


def get_registered_op(opname, domain, version):
    """Return the symbolic function registered for ``opname`` in (domain, version).

    If ``domain`` or ``version`` is None, a warning is emitted and the lookup
    still proceeds. The warning uses ``stacklevel=2`` so it points at the
    caller rather than at this module.

    Raises:
        KeyError: if no symbolic is registered for ``opname`` under
            (domain, version). The message names the operator, domain and
            version. Use ``is_registered_op`` to probe without raising.
    """
    if domain is None or version is None:
        warnings.warn("ONNX export failed. The ONNX domain and/or version are None.", stacklevel=2)
    ops = _registry.get((domain, version), {})
    if opname not in ops:
        # Same exception type as the plain dict lookup, but with a message that
        # says which op/domain/version combination is missing.
        raise KeyError("No symbolic function is registered for operator '{}' "
                       "in domain '{}', opset version {}.".format(opname, domain, version))
    return ops[opname]