mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ONNX] Update ONNX documentation to include unsupported operators (#84496)
- Update ONNX documentation to include unsupported operators - Include aten, quantized and other namespaces Pull Request resolved: https://github.com/pytorch/pytorch/pull/84496 Approved by: https://github.com/AllenTiTaiWang, https://github.com/BowenBao, https://github.com/kit1980
This commit is contained in:
parent
46843be1e6
commit
d6c2080eb4
1
.github/merge_rules.yaml
vendored
1
.github/merge_rules.yaml
vendored
|
|
@ -3,6 +3,7 @@
|
||||||
- .jenkins/caffe2/*
|
- .jenkins/caffe2/*
|
||||||
- aten/src/ATen/core/interned_strings.h
|
- aten/src/ATen/core/interned_strings.h
|
||||||
- docs/source/onnx.rst
|
- docs/source/onnx.rst
|
||||||
|
- docs/source/onnx*
|
||||||
- docs/source/scripts/onnx/**
|
- docs/source/scripts/onnx/**
|
||||||
- scripts/onnx/**
|
- scripts/onnx/**
|
||||||
- test/jit/test_export_modes.py
|
- test/jit/test_export_modes.py
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,30 @@
|
||||||
:orphan:
|
:orphan:
|
||||||
|
|
||||||
ONNX supported ATen operators
|
ONNX supported TorchScript operators
|
||||||
=============================
|
====================================
|
||||||
|
|
||||||
This file is automatically generated during the documentation build
|
.. This file is automatically generated during the documentation build
|
||||||
by cross referencing ONNX operator symbolics with Torch JIT operators via
|
.. by cross referencing ONNX operator symbolics with TorchScript operators via
|
||||||
``docs/source/scripts/build_onnx_supported_aten_op_csv_table.py``.
|
.. ``docs/source/scripts/build_onnx_supported_aten_op_csv_table.py``.
|
||||||
Do not modify directly and instead `rebuild the docs <https://github.com/pytorch/pytorch#building-the-documentation>`_.
|
.. Do not modify directly and instead `rebuild the docs <https://github.com/pytorch/pytorch#building-the-documentation>`_.
|
||||||
|
|
||||||
.. csv-table:: Supported ATen operators
|
This page lists the TorchScript operators that are supported/unsupported by ONNX export.
|
||||||
:file: ../build/auto_gen_aten_op_list.csv
|
|
||||||
:widths: 30, 70
|
Supported operators
|
||||||
|
-------------------
|
||||||
|
|
||||||
|
.. csv-table:: ONNX support for TorchScript operators
|
||||||
|
:file: ../build/onnx/auto_gen_supported_op_list.csv
|
||||||
|
:widths: 70, 30
|
||||||
|
:header-rows: 1
|
||||||
|
|
||||||
|
|
||||||
|
Unsupported operators
|
||||||
|
---------------------
|
||||||
|
|
||||||
|
Operators that are not yet supported
|
||||||
|
|
||||||
|
.. csv-table:: Unsupported operators
|
||||||
|
:file: ../build/onnx/auto_gen_unsupported_op_list.csv
|
||||||
|
:widths: 70, 30
|
||||||
:header-rows: 1
|
:header-rows: 1
|
||||||
|
|
|
||||||
|
|
@ -8,18 +8,59 @@ import os
|
||||||
from torch.onnx import _onnx_supported_ops
|
from torch.onnx import _onnx_supported_ops
|
||||||
|
|
||||||
# Constants
|
# Constants
|
||||||
BUILD_DIR = "build"
|
BUILD_DIR = "build/onnx"
|
||||||
AUTO_GEN_ATEN_OPS_CSV_FILE = "auto_gen_aten_op_list.csv"
|
SUPPORTED_OPS_CSV_FILE = "auto_gen_supported_op_list.csv"
|
||||||
|
UNSUPPORTED_OPS_CSV_FILE = "auto_gen_unsupported_op_list.csv"
|
||||||
|
|
||||||
|
|
||||||
|
def _sort_key(namespaced_opname):
|
||||||
|
return tuple(reversed(namespaced_opname.split("::")))
|
||||||
|
|
||||||
|
|
||||||
|
def _get_op_lists():
|
||||||
|
all_schemas = _onnx_supported_ops.all_forward_schemas()
|
||||||
|
symbolic_schemas = _onnx_supported_ops.all_symbolics_schemas()
|
||||||
|
supported_result = set()
|
||||||
|
not_supported_result = set()
|
||||||
|
for opname in all_schemas:
|
||||||
|
if opname.endswith("_"):
|
||||||
|
opname = opname[:-1]
|
||||||
|
if opname in symbolic_schemas:
|
||||||
|
# Supported op
|
||||||
|
opsets = symbolic_schemas[opname].opsets
|
||||||
|
supported_result.add(
|
||||||
|
(
|
||||||
|
opname,
|
||||||
|
f"Since opset {opsets[0]}",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Unsupported op
|
||||||
|
not_supported_result.add(
|
||||||
|
(
|
||||||
|
opname,
|
||||||
|
"Not yet supported",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
sorted(supported_result, key=lambda x: _sort_key(x[0])),
|
||||||
|
sorted(not_supported_result),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
os.makedirs(BUILD_DIR, exist_ok=True)
|
os.makedirs(BUILD_DIR, exist_ok=True)
|
||||||
|
|
||||||
aten_list = _onnx_supported_ops.onnx_supported_ops()
|
supported, unsupported = _get_op_lists()
|
||||||
|
|
||||||
with open(os.path.join(BUILD_DIR, AUTO_GEN_ATEN_OPS_CSV_FILE), "w") as f:
|
with open(os.path.join(BUILD_DIR, SUPPORTED_OPS_CSV_FILE), "w") as f:
|
||||||
f.write("Operator,opset_version(s)\n")
|
f.write("Operator,opset_version(s)\n")
|
||||||
for name, opset_version in aten_list:
|
for name, opset_version in supported:
|
||||||
|
f.write(f'"``{name}``","{opset_version}"\n')
|
||||||
|
|
||||||
|
with open(os.path.join(BUILD_DIR, UNSUPPORTED_OPS_CSV_FILE), "w") as f:
|
||||||
|
f.write("Operator,opset_version(s)\n")
|
||||||
|
for name, opset_version in unsupported:
|
||||||
f.write(f'"``{name}``","{opset_version}"\n')
|
f.write(f'"``{name}``","{opset_version}"\n')
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -24,13 +24,15 @@ class _TorchSchema:
|
||||||
self.opsets = []
|
self.opsets = []
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
s = f"{self.name}.{self.overload_name}("
|
s = (
|
||||||
s += ", ".join(self.arguments)
|
f"{self.name}.{self.overload_name}("
|
||||||
s += ") -> ("
|
+ ", ".join(self.arguments)
|
||||||
s += ", ".join(self.returns)
|
+ ") -> ("
|
||||||
s += ")"
|
+ ", ".join(self.returns)
|
||||||
s += " in opsets "
|
+ ")"
|
||||||
s += ", ".join(str(opset) for opset in self.opsets)
|
+ " in opsets "
|
||||||
|
+ ", ".join(str(opset) for opset in self.opsets)
|
||||||
|
)
|
||||||
return s
|
return s
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
|
|
@ -50,14 +52,6 @@ class _TorchSchema:
|
||||||
return "backward" in self.name
|
return "backward" in self.name
|
||||||
|
|
||||||
|
|
||||||
def _all_aten_forward_schemas():
|
|
||||||
"""Creates a list of _TorchSchema for all aten schemas."""
|
|
||||||
torch_schemas = [_TorchSchema(s) for s in _C._jit_get_all_schemas()]
|
|
||||||
torch_schemas = sorted(torch_schemas, key=lambda x: x.name)
|
|
||||||
aten_schemas = [s for s in torch_schemas if s.is_aten() and not s.is_backward()]
|
|
||||||
return aten_schemas
|
|
||||||
|
|
||||||
|
|
||||||
def _symbolic_argument_count(func):
|
def _symbolic_argument_count(func):
|
||||||
params = []
|
params = []
|
||||||
signature = inspect.signature(func)
|
signature = inspect.signature(func)
|
||||||
|
|
@ -72,7 +66,14 @@ def _symbolic_argument_count(func):
|
||||||
return params
|
return params
|
||||||
|
|
||||||
|
|
||||||
def _all_symbolics_schemas() -> Dict[str, _TorchSchema]:
|
def all_forward_schemas() -> Dict[str, _TorchSchema]:
|
||||||
|
"""Returns schemas for all TorchScript forward ops."""
|
||||||
|
torch_schemas = [_TorchSchema(s) for s in _C._jit_get_all_schemas()]
|
||||||
|
return {schema.name: schema for schema in torch_schemas if not schema.is_backward()}
|
||||||
|
|
||||||
|
|
||||||
|
def all_symbolics_schemas() -> Dict[str, _TorchSchema]:
|
||||||
|
"""Returns schemas for all onnx supported ops."""
|
||||||
symbolics_schemas = {}
|
symbolics_schemas = {}
|
||||||
|
|
||||||
for name in registration.registry.all_functions():
|
for name in registration.registry.all_functions():
|
||||||
|
|
@ -94,19 +95,3 @@ def _all_symbolics_schemas() -> Dict[str, _TorchSchema]:
|
||||||
symbolics_schemas[name] = symbolics_schema
|
symbolics_schemas[name] = symbolics_schema
|
||||||
|
|
||||||
return symbolics_schemas
|
return symbolics_schemas
|
||||||
|
|
||||||
|
|
||||||
def onnx_supported_ops():
|
|
||||||
aten_schemas = _all_aten_forward_schemas()
|
|
||||||
symbolic_schemas = _all_symbolics_schemas()
|
|
||||||
torch_schemas = set(symbolic_schemas.values())
|
|
||||||
supported_ops = []
|
|
||||||
onnx_supported = []
|
|
||||||
for schema in aten_schemas:
|
|
||||||
if schema in torch_schemas:
|
|
||||||
opname = schema.name
|
|
||||||
opsets = symbolic_schemas[opname].opsets
|
|
||||||
if schema not in supported_ops:
|
|
||||||
supported_ops.append(symbolic_schemas[opname])
|
|
||||||
onnx_supported.append((opname, " ".join(str(o) for o in opsets)))
|
|
||||||
return sorted(onnx_supported, key=lambda x: x[0])
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user