[ONNX] Update ONNX documentation to include unsupported operators (#84496)

- Update ONNX documentation to include unsupported operators
- Include aten, quantized and other namespaces
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84496
Approved by: https://github.com/AllenTiTaiWang, https://github.com/BowenBao, https://github.com/kit1980
This commit is contained in:
Justin Chu 2022-09-16 21:56:41 +00:00 committed by PyTorch MergeBot
parent 46843be1e6
commit d6c2080eb4
4 changed files with 89 additions and 46 deletions

View File

@ -3,6 +3,7 @@
- .jenkins/caffe2/* - .jenkins/caffe2/*
- aten/src/ATen/core/interned_strings.h - aten/src/ATen/core/interned_strings.h
- docs/source/onnx.rst - docs/source/onnx.rst
- docs/source/onnx*
- docs/source/scripts/onnx/** - docs/source/scripts/onnx/**
- scripts/onnx/** - scripts/onnx/**
- test/jit/test_export_modes.py - test/jit/test_export_modes.py

View File

@ -1,14 +1,30 @@
:orphan: :orphan:
ONNX supported ATen operators ONNX supported TorchScript operators
============================= ====================================
This file is automatically generated during the documentation build .. This file is automatically generated during the documentation build
by cross referencing ONNX operator symbolics with Torch JIT operators via .. by cross referencing ONNX operator symbolics with TorchScript operators via
``docs/source/scripts/build_onnx_supported_aten_op_csv_table.py``. .. ``docs/source/scripts/build_onnx_supported_aten_op_csv_table.py``.
Do not modify directly and instead `rebuild the docs <https://github.com/pytorch/pytorch#building-the-documentation>`_. .. Do not modify directly and instead `rebuild the docs <https://github.com/pytorch/pytorch#building-the-documentation>`_.
.. csv-table:: Supported ATen operators This page lists the TorchScript operators that are supported/unsupported by ONNX export.
:file: ../build/auto_gen_aten_op_list.csv
:widths: 30, 70 Supported operators
-------------------
.. csv-table:: ONNX support for TorchScript operators
:file: ../build/onnx/auto_gen_supported_op_list.csv
:widths: 70, 30
:header-rows: 1
Unsupported operators
---------------------
Operators that are not yet supported
.. csv-table:: Unsupported operators
:file: ../build/onnx/auto_gen_unsupported_op_list.csv
:widths: 70, 30
:header-rows: 1 :header-rows: 1

View File

@ -8,18 +8,59 @@ import os
from torch.onnx import _onnx_supported_ops from torch.onnx import _onnx_supported_ops
# Constants # Constants
BUILD_DIR = "build" BUILD_DIR = "build/onnx"
AUTO_GEN_ATEN_OPS_CSV_FILE = "auto_gen_aten_op_list.csv" SUPPORTED_OPS_CSV_FILE = "auto_gen_supported_op_list.csv"
UNSUPPORTED_OPS_CSV_FILE = "auto_gen_unsupported_op_list.csv"
def _sort_key(namespaced_opname):
return tuple(reversed(namespaced_opname.split("::")))
def _get_op_lists():
all_schemas = _onnx_supported_ops.all_forward_schemas()
symbolic_schemas = _onnx_supported_ops.all_symbolics_schemas()
supported_result = set()
not_supported_result = set()
for opname in all_schemas:
if opname.endswith("_"):
opname = opname[:-1]
if opname in symbolic_schemas:
# Supported op
opsets = symbolic_schemas[opname].opsets
supported_result.add(
(
opname,
f"Since opset {opsets[0]}",
)
)
else:
# Unsupported op
not_supported_result.add(
(
opname,
"Not yet supported",
)
)
return (
sorted(supported_result, key=lambda x: _sort_key(x[0])),
sorted(not_supported_result),
)
def main(): def main():
os.makedirs(BUILD_DIR, exist_ok=True) os.makedirs(BUILD_DIR, exist_ok=True)
aten_list = _onnx_supported_ops.onnx_supported_ops() supported, unsupported = _get_op_lists()
with open(os.path.join(BUILD_DIR, AUTO_GEN_ATEN_OPS_CSV_FILE), "w") as f: with open(os.path.join(BUILD_DIR, SUPPORTED_OPS_CSV_FILE), "w") as f:
f.write("Operator,opset_version(s)\n") f.write("Operator,opset_version(s)\n")
for name, opset_version in aten_list: for name, opset_version in supported:
f.write(f'"``{name}``","{opset_version}"\n')
with open(os.path.join(BUILD_DIR, UNSUPPORTED_OPS_CSV_FILE), "w") as f:
f.write("Operator,opset_version(s)\n")
for name, opset_version in unsupported:
f.write(f'"``{name}``","{opset_version}"\n') f.write(f'"``{name}``","{opset_version}"\n')

View File

@ -24,13 +24,15 @@ class _TorchSchema:
self.opsets = [] self.opsets = []
def __str__(self) -> str: def __str__(self) -> str:
s = f"{self.name}.{self.overload_name}(" s = (
s += ", ".join(self.arguments) f"{self.name}.{self.overload_name}("
s += ") -> (" + ", ".join(self.arguments)
s += ", ".join(self.returns) + ") -> ("
s += ")" + ", ".join(self.returns)
s += " in opsets " + ")"
s += ", ".join(str(opset) for opset in self.opsets) + " in opsets "
+ ", ".join(str(opset) for opset in self.opsets)
)
return s return s
def __hash__(self): def __hash__(self):
@ -50,14 +52,6 @@ class _TorchSchema:
return "backward" in self.name return "backward" in self.name
def _all_aten_forward_schemas():
"""Creates a list of _TorchSchema for all aten schemas."""
torch_schemas = [_TorchSchema(s) for s in _C._jit_get_all_schemas()]
torch_schemas = sorted(torch_schemas, key=lambda x: x.name)
aten_schemas = [s for s in torch_schemas if s.is_aten() and not s.is_backward()]
return aten_schemas
def _symbolic_argument_count(func): def _symbolic_argument_count(func):
params = [] params = []
signature = inspect.signature(func) signature = inspect.signature(func)
@ -72,7 +66,14 @@ def _symbolic_argument_count(func):
return params return params
def _all_symbolics_schemas() -> Dict[str, _TorchSchema]: def all_forward_schemas() -> Dict[str, _TorchSchema]:
"""Returns schemas for all TorchScript forward ops."""
torch_schemas = [_TorchSchema(s) for s in _C._jit_get_all_schemas()]
return {schema.name: schema for schema in torch_schemas if not schema.is_backward()}
def all_symbolics_schemas() -> Dict[str, _TorchSchema]:
"""Returns schemas for all onnx supported ops."""
symbolics_schemas = {} symbolics_schemas = {}
for name in registration.registry.all_functions(): for name in registration.registry.all_functions():
@ -94,19 +95,3 @@ def _all_symbolics_schemas() -> Dict[str, _TorchSchema]:
symbolics_schemas[name] = symbolics_schema symbolics_schemas[name] = symbolics_schema
return symbolics_schemas return symbolics_schemas
def onnx_supported_ops():
aten_schemas = _all_aten_forward_schemas()
symbolic_schemas = _all_symbolics_schemas()
torch_schemas = set(symbolic_schemas.values())
supported_ops = []
onnx_supported = []
for schema in aten_schemas:
if schema in torch_schemas:
opname = schema.name
opsets = symbolic_schemas[opname].opsets
if schema not in supported_ops:
supported_ops.append(symbolic_schemas[opname])
onnx_supported.append((opname, " ".join(str(o) for o in opsets)))
return sorted(onnx_supported, key=lambda x: x[0])