mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ONNX] Update ONNX documentation to include unsupported operators (#84496)
- Update ONNX documentation to include unsupported operators - Include aten, quantized and other namespaces Pull Request resolved: https://github.com/pytorch/pytorch/pull/84496 Approved by: https://github.com/AllenTiTaiWang, https://github.com/BowenBao, https://github.com/kit1980
This commit is contained in:
parent
46843be1e6
commit
d6c2080eb4
1
.github/merge_rules.yaml
vendored
1
.github/merge_rules.yaml
vendored
|
|
@ -3,6 +3,7 @@
|
|||
- .jenkins/caffe2/*
|
||||
- aten/src/ATen/core/interned_strings.h
|
||||
- docs/source/onnx.rst
|
||||
- docs/source/onnx*
|
||||
- docs/source/scripts/onnx/**
|
||||
- scripts/onnx/**
|
||||
- test/jit/test_export_modes.py
|
||||
|
|
|
|||
|
|
@ -1,14 +1,30 @@
|
|||
:orphan:
|
||||
|
||||
ONNX supported ATen operators
|
||||
=============================
|
||||
ONNX supported TorchScript operators
|
||||
====================================
|
||||
|
||||
This file is automatically generated during the documentation build
|
||||
by cross referencing ONNX operator symbolics with Torch JIT operators via
|
||||
``docs/source/scripts/build_onnx_supported_aten_op_csv_table.py``.
|
||||
Do not modify directly and instead `rebuild the docs <https://github.com/pytorch/pytorch#building-the-documentation>`_.
|
||||
.. This file is automatically generated during the documentation build
|
||||
.. by cross referencing ONNX operator symbolics with TorchScript operators via
|
||||
.. ``docs/source/scripts/build_onnx_supported_aten_op_csv_table.py``.
|
||||
.. Do not modify directly and instead `rebuild the docs <https://github.com/pytorch/pytorch#building-the-documentation>`_.
|
||||
|
||||
.. csv-table:: Supported ATen operators
|
||||
:file: ../build/auto_gen_aten_op_list.csv
|
||||
:widths: 30, 70
|
||||
This page lists the TorchScript operators that are supported/unsupported by ONNX export.
|
||||
|
||||
Supported operators
|
||||
-------------------
|
||||
|
||||
.. csv-table:: ONNX support for TorchScript operators
|
||||
:file: ../build/onnx/auto_gen_supported_op_list.csv
|
||||
:widths: 70, 30
|
||||
:header-rows: 1
|
||||
|
||||
|
||||
Unsupported operators
|
||||
---------------------
|
||||
|
||||
Operators that are not yet supported
|
||||
|
||||
.. csv-table:: Unsupported operators
|
||||
:file: ../build/onnx/auto_gen_unsupported_op_list.csv
|
||||
:widths: 70, 30
|
||||
:header-rows: 1
|
||||
|
|
|
|||
|
|
@ -8,18 +8,59 @@ import os
|
|||
from torch.onnx import _onnx_supported_ops
|
||||
|
||||
# Constants
|
||||
BUILD_DIR = "build"
|
||||
AUTO_GEN_ATEN_OPS_CSV_FILE = "auto_gen_aten_op_list.csv"
|
||||
BUILD_DIR = "build/onnx"
|
||||
SUPPORTED_OPS_CSV_FILE = "auto_gen_supported_op_list.csv"
|
||||
UNSUPPORTED_OPS_CSV_FILE = "auto_gen_unsupported_op_list.csv"
|
||||
|
||||
|
||||
def _sort_key(namespaced_opname):
|
||||
return tuple(reversed(namespaced_opname.split("::")))
|
||||
|
||||
|
||||
def _get_op_lists():
|
||||
all_schemas = _onnx_supported_ops.all_forward_schemas()
|
||||
symbolic_schemas = _onnx_supported_ops.all_symbolics_schemas()
|
||||
supported_result = set()
|
||||
not_supported_result = set()
|
||||
for opname in all_schemas:
|
||||
if opname.endswith("_"):
|
||||
opname = opname[:-1]
|
||||
if opname in symbolic_schemas:
|
||||
# Supported op
|
||||
opsets = symbolic_schemas[opname].opsets
|
||||
supported_result.add(
|
||||
(
|
||||
opname,
|
||||
f"Since opset {opsets[0]}",
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Unsupported op
|
||||
not_supported_result.add(
|
||||
(
|
||||
opname,
|
||||
"Not yet supported",
|
||||
)
|
||||
)
|
||||
return (
|
||||
sorted(supported_result, key=lambda x: _sort_key(x[0])),
|
||||
sorted(not_supported_result),
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
os.makedirs(BUILD_DIR, exist_ok=True)
|
||||
|
||||
aten_list = _onnx_supported_ops.onnx_supported_ops()
|
||||
supported, unsupported = _get_op_lists()
|
||||
|
||||
with open(os.path.join(BUILD_DIR, AUTO_GEN_ATEN_OPS_CSV_FILE), "w") as f:
|
||||
with open(os.path.join(BUILD_DIR, SUPPORTED_OPS_CSV_FILE), "w") as f:
|
||||
f.write("Operator,opset_version(s)\n")
|
||||
for name, opset_version in aten_list:
|
||||
for name, opset_version in supported:
|
||||
f.write(f'"``{name}``","{opset_version}"\n')
|
||||
|
||||
with open(os.path.join(BUILD_DIR, UNSUPPORTED_OPS_CSV_FILE), "w") as f:
|
||||
f.write("Operator,opset_version(s)\n")
|
||||
for name, opset_version in unsupported:
|
||||
f.write(f'"``{name}``","{opset_version}"\n')
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -24,13 +24,15 @@ class _TorchSchema:
|
|||
self.opsets = []
|
||||
|
||||
def __str__(self) -> str:
|
||||
s = f"{self.name}.{self.overload_name}("
|
||||
s += ", ".join(self.arguments)
|
||||
s += ") -> ("
|
||||
s += ", ".join(self.returns)
|
||||
s += ")"
|
||||
s += " in opsets "
|
||||
s += ", ".join(str(opset) for opset in self.opsets)
|
||||
s = (
|
||||
f"{self.name}.{self.overload_name}("
|
||||
+ ", ".join(self.arguments)
|
||||
+ ") -> ("
|
||||
+ ", ".join(self.returns)
|
||||
+ ")"
|
||||
+ " in opsets "
|
||||
+ ", ".join(str(opset) for opset in self.opsets)
|
||||
)
|
||||
return s
|
||||
|
||||
def __hash__(self):
|
||||
|
|
@ -50,14 +52,6 @@ class _TorchSchema:
|
|||
return "backward" in self.name
|
||||
|
||||
|
||||
def _all_aten_forward_schemas():
|
||||
"""Creates a list of _TorchSchema for all aten schemas."""
|
||||
torch_schemas = [_TorchSchema(s) for s in _C._jit_get_all_schemas()]
|
||||
torch_schemas = sorted(torch_schemas, key=lambda x: x.name)
|
||||
aten_schemas = [s for s in torch_schemas if s.is_aten() and not s.is_backward()]
|
||||
return aten_schemas
|
||||
|
||||
|
||||
def _symbolic_argument_count(func):
|
||||
params = []
|
||||
signature = inspect.signature(func)
|
||||
|
|
@ -72,7 +66,14 @@ def _symbolic_argument_count(func):
|
|||
return params
|
||||
|
||||
|
||||
def _all_symbolics_schemas() -> Dict[str, _TorchSchema]:
|
||||
def all_forward_schemas() -> Dict[str, _TorchSchema]:
|
||||
"""Returns schemas for all TorchScript forward ops."""
|
||||
torch_schemas = [_TorchSchema(s) for s in _C._jit_get_all_schemas()]
|
||||
return {schema.name: schema for schema in torch_schemas if not schema.is_backward()}
|
||||
|
||||
|
||||
def all_symbolics_schemas() -> Dict[str, _TorchSchema]:
|
||||
"""Returns schemas for all onnx supported ops."""
|
||||
symbolics_schemas = {}
|
||||
|
||||
for name in registration.registry.all_functions():
|
||||
|
|
@ -94,19 +95,3 @@ def _all_symbolics_schemas() -> Dict[str, _TorchSchema]:
|
|||
symbolics_schemas[name] = symbolics_schema
|
||||
|
||||
return symbolics_schemas
|
||||
|
||||
|
||||
def onnx_supported_ops():
|
||||
aten_schemas = _all_aten_forward_schemas()
|
||||
symbolic_schemas = _all_symbolics_schemas()
|
||||
torch_schemas = set(symbolic_schemas.values())
|
||||
supported_ops = []
|
||||
onnx_supported = []
|
||||
for schema in aten_schemas:
|
||||
if schema in torch_schemas:
|
||||
opname = schema.name
|
||||
opsets = symbolic_schemas[opname].opsets
|
||||
if schema not in supported_ops:
|
||||
supported_ops.append(symbolic_schemas[opname])
|
||||
onnx_supported.append((opname, " ".join(str(o) for o in opsets)))
|
||||
return sorted(onnx_supported, key=lambda x: x[0])
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user