[ONNX] Update ONNX documentation to include unsupported operators (#84496)

- Update ONNX documentation to include unsupported operators
- Include aten, quantized and other namespaces
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84496
Approved by: https://github.com/AllenTiTaiWang, https://github.com/BowenBao, https://github.com/kit1980
This commit is contained in:
Justin Chu 2022-09-16 21:56:41 +00:00 committed by PyTorch MergeBot
parent 46843be1e6
commit d6c2080eb4
4 changed files with 89 additions and 46 deletions

View File

@ -3,6 +3,7 @@
- .jenkins/caffe2/*
- aten/src/ATen/core/interned_strings.h
- docs/source/onnx.rst
- docs/source/onnx*
- docs/source/scripts/onnx/**
- scripts/onnx/**
- test/jit/test_export_modes.py

View File

@ -1,14 +1,30 @@
:orphan:
ONNX supported ATen operators
=============================
ONNX supported TorchScript operators
====================================
This file is automatically generated during the documentation build
by cross referencing ONNX operator symbolics with Torch JIT operators via
``docs/source/scripts/build_onnx_supported_aten_op_csv_table.py``.
Do not modify directly and instead `rebuild the docs <https://github.com/pytorch/pytorch#building-the-documentation>`_.
.. This file is automatically generated during the documentation build
.. by cross referencing ONNX operator symbolics with TorchScript operators via
.. ``docs/source/scripts/build_onnx_supported_aten_op_csv_table.py``.
.. Do not modify directly and instead `rebuild the docs <https://github.com/pytorch/pytorch#building-the-documentation>`_.
.. csv-table:: Supported ATen operators
:file: ../build/auto_gen_aten_op_list.csv
:widths: 30, 70
This page lists the TorchScript operators that are supported/unsupported by ONNX export.
Supported operators
-------------------
.. csv-table:: ONNX support for TorchScript operators
:file: ../build/onnx/auto_gen_supported_op_list.csv
:widths: 70, 30
:header-rows: 1
Unsupported operators
---------------------
Operators that are not yet supported
.. csv-table:: Unsupported operators
:file: ../build/onnx/auto_gen_unsupported_op_list.csv
:widths: 70, 30
:header-rows: 1

View File

@ -8,18 +8,59 @@ import os
from torch.onnx import _onnx_supported_ops
# Constants
BUILD_DIR = "build"
AUTO_GEN_ATEN_OPS_CSV_FILE = "auto_gen_aten_op_list.csv"
BUILD_DIR = "build/onnx"
SUPPORTED_OPS_CSV_FILE = "auto_gen_supported_op_list.csv"
UNSUPPORTED_OPS_CSV_FILE = "auto_gen_unsupported_op_list.csv"
def _sort_key(namespaced_opname):
return tuple(reversed(namespaced_opname.split("::")))
def _get_op_lists():
all_schemas = _onnx_supported_ops.all_forward_schemas()
symbolic_schemas = _onnx_supported_ops.all_symbolics_schemas()
supported_result = set()
not_supported_result = set()
for opname in all_schemas:
if opname.endswith("_"):
opname = opname[:-1]
if opname in symbolic_schemas:
# Supported op
opsets = symbolic_schemas[opname].opsets
supported_result.add(
(
opname,
f"Since opset {opsets[0]}",
)
)
else:
# Unsupported op
not_supported_result.add(
(
opname,
"Not yet supported",
)
)
return (
sorted(supported_result, key=lambda x: _sort_key(x[0])),
sorted(not_supported_result),
)
def main():
os.makedirs(BUILD_DIR, exist_ok=True)
aten_list = _onnx_supported_ops.onnx_supported_ops()
supported, unsupported = _get_op_lists()
with open(os.path.join(BUILD_DIR, AUTO_GEN_ATEN_OPS_CSV_FILE), "w") as f:
with open(os.path.join(BUILD_DIR, SUPPORTED_OPS_CSV_FILE), "w") as f:
f.write("Operator,opset_version(s)\n")
for name, opset_version in aten_list:
for name, opset_version in supported:
f.write(f'"``{name}``","{opset_version}"\n')
with open(os.path.join(BUILD_DIR, UNSUPPORTED_OPS_CSV_FILE), "w") as f:
f.write("Operator,opset_version(s)\n")
for name, opset_version in unsupported:
f.write(f'"``{name}``","{opset_version}"\n')

View File

@ -24,13 +24,15 @@ class _TorchSchema:
self.opsets = []
def __str__(self) -> str:
s = f"{self.name}.{self.overload_name}("
s += ", ".join(self.arguments)
s += ") -> ("
s += ", ".join(self.returns)
s += ")"
s += " in opsets "
s += ", ".join(str(opset) for opset in self.opsets)
s = (
f"{self.name}.{self.overload_name}("
+ ", ".join(self.arguments)
+ ") -> ("
+ ", ".join(self.returns)
+ ")"
+ " in opsets "
+ ", ".join(str(opset) for opset in self.opsets)
)
return s
def __hash__(self):
@ -50,14 +52,6 @@ class _TorchSchema:
return "backward" in self.name
def _all_aten_forward_schemas():
"""Creates a list of _TorchSchema for all aten schemas."""
torch_schemas = [_TorchSchema(s) for s in _C._jit_get_all_schemas()]
torch_schemas = sorted(torch_schemas, key=lambda x: x.name)
aten_schemas = [s for s in torch_schemas if s.is_aten() and not s.is_backward()]
return aten_schemas
def _symbolic_argument_count(func):
params = []
signature = inspect.signature(func)
@ -72,7 +66,14 @@ def _symbolic_argument_count(func):
return params
def _all_symbolics_schemas() -> Dict[str, _TorchSchema]:
def all_forward_schemas() -> Dict[str, _TorchSchema]:
"""Returns schemas for all TorchScript forward ops."""
torch_schemas = [_TorchSchema(s) for s in _C._jit_get_all_schemas()]
return {schema.name: schema for schema in torch_schemas if not schema.is_backward()}
def all_symbolics_schemas() -> Dict[str, _TorchSchema]:
"""Returns schemas for all onnx supported ops."""
symbolics_schemas = {}
for name in registration.registry.all_functions():
@ -94,19 +95,3 @@ def _all_symbolics_schemas() -> Dict[str, _TorchSchema]:
symbolics_schemas[name] = symbolics_schema
return symbolics_schemas
def onnx_supported_ops():
aten_schemas = _all_aten_forward_schemas()
symbolic_schemas = _all_symbolics_schemas()
torch_schemas = set(symbolic_schemas.values())
supported_ops = []
onnx_supported = []
for schema in aten_schemas:
if schema in torch_schemas:
opname = schema.name
opsets = symbolic_schemas[opname].opsets
if schema not in supported_ops:
supported_ops.append(symbolic_schemas[opname])
onnx_supported.append((opname, " ".join(str(o) for o in opsets)))
return sorted(onnx_supported, key=lambda x: x[0])