diff --git a/.github/merge_rules.yaml b/.github/merge_rules.yaml index e7c56191615..7f091a55c62 100644 --- a/.github/merge_rules.yaml +++ b/.github/merge_rules.yaml @@ -3,6 +3,7 @@ - .jenkins/caffe2/* - aten/src/ATen/core/interned_strings.h - docs/source/onnx.rst + - docs/source/onnx* - docs/source/scripts/onnx/** - scripts/onnx/** - test/jit/test_export_modes.py diff --git a/docs/source/onnx_supported_aten_ops.rst b/docs/source/onnx_supported_aten_ops.rst index d6bf535e2e7..ce075b59a7b 100644 --- a/docs/source/onnx_supported_aten_ops.rst +++ b/docs/source/onnx_supported_aten_ops.rst @@ -1,14 +1,30 @@ :orphan: -ONNX supported ATen operators -============================= +ONNX supported TorchScript operators +==================================== -This file is automatically generated during the documentation build -by cross referencing ONNX operator symbolics with Torch JIT operators via -``docs/source/scripts/build_onnx_supported_aten_op_csv_table.py``. -Do not modify directly and instead `rebuild the docs `_. +.. This file is automatically generated during the documentation build +.. by cross referencing ONNX operator symbolics with TorchScript operators via +.. ``docs/source/scripts/build_onnx_supported_aten_op_csv_table.py``. +.. Do not modify directly and instead `rebuild the docs `_. -.. csv-table:: Supported ATen operators - :file: ../build/auto_gen_aten_op_list.csv - :widths: 30, 70 +This page lists the TorchScript operators that are supported/unsupported by ONNX export. + +Supported operators +------------------- + +.. csv-table:: ONNX support for TorchScript operators + :file: ../build/onnx/auto_gen_supported_op_list.csv + :widths: 70, 30 + :header-rows: 1 + + +Unsupported operators +--------------------- + +Operators that are not yet supported + +.. csv-table:: Unsupported operators + :file: ../build/onnx/auto_gen_unsupported_op_list.csv + :widths: 70, 30 :header-rows: 1 diff --git a/docs/source/scripts/onnx/build_onnx_supported_aten_op_csv_table.py b/docs/source/scripts/onnx/build_onnx_supported_aten_op_csv_table.py index 31bc2e8c848..e4b4eb11fdf 100644 --- a/docs/source/scripts/onnx/build_onnx_supported_aten_op_csv_table.py +++ b/docs/source/scripts/onnx/build_onnx_supported_aten_op_csv_table.py @@ -8,18 +8,59 @@ import os from torch.onnx import _onnx_supported_ops # Constants -BUILD_DIR = "build" -AUTO_GEN_ATEN_OPS_CSV_FILE = "auto_gen_aten_op_list.csv" +BUILD_DIR = "build/onnx" +SUPPORTED_OPS_CSV_FILE = "auto_gen_supported_op_list.csv" +UNSUPPORTED_OPS_CSV_FILE = "auto_gen_unsupported_op_list.csv" + + +def _sort_key(namespaced_opname): + return tuple(reversed(namespaced_opname.split("::"))) + + +def _get_op_lists(): + all_schemas = _onnx_supported_ops.all_forward_schemas() + symbolic_schemas = _onnx_supported_ops.all_symbolics_schemas() + supported_result = set() + not_supported_result = set() + for opname in all_schemas: + if opname.endswith("_"): + opname = opname[:-1] + if opname in symbolic_schemas: + # Supported op + opsets = symbolic_schemas[opname].opsets + supported_result.add( + ( + opname, + f"Since opset {opsets[0]}", + ) + ) + else: + # Unsupported op + not_supported_result.add( + ( + opname, + "Not yet supported", + ) + ) + return ( + sorted(supported_result, key=lambda x: _sort_key(x[0])), + sorted(not_supported_result), + ) def main(): os.makedirs(BUILD_DIR, exist_ok=True) - aten_list = _onnx_supported_ops.onnx_supported_ops() + supported, unsupported = _get_op_lists() - with open(os.path.join(BUILD_DIR, AUTO_GEN_ATEN_OPS_CSV_FILE), "w") as f: + with open(os.path.join(BUILD_DIR, SUPPORTED_OPS_CSV_FILE), "w") as f: f.write("Operator,opset_version(s)\n") - for name, opset_version in aten_list: + for name, opset_version in supported: + f.write(f'"``{name}``","{opset_version}"\n') + + with open(os.path.join(BUILD_DIR, UNSUPPORTED_OPS_CSV_FILE), "w") as f: + f.write("Operator,opset_version(s)\n") + for name, opset_version in unsupported: f.write(f'"``{name}``","{opset_version}"\n') diff --git a/torch/onnx/_onnx_supported_ops.py b/torch/onnx/_onnx_supported_ops.py index b5a190d8546..2611b0d81e9 100644 --- a/torch/onnx/_onnx_supported_ops.py +++ b/torch/onnx/_onnx_supported_ops.py @@ -24,13 +24,15 @@ class _TorchSchema: self.opsets = [] def __str__(self) -> str: - s = f"{self.name}.{self.overload_name}(" - s += ", ".join(self.arguments) - s += ") -> (" - s += ", ".join(self.returns) - s += ")" - s += " in opsets " - s += ", ".join(str(opset) for opset in self.opsets) + s = ( + f"{self.name}.{self.overload_name}(" + + ", ".join(self.arguments) + + ") -> (" + + ", ".join(self.returns) + + ")" + + " in opsets " + + ", ".join(str(opset) for opset in self.opsets) + ) return s def __hash__(self): @@ -50,14 +52,6 @@ class _TorchSchema: return "backward" in self.name -def _all_aten_forward_schemas(): - """Creates a list of _TorchSchema for all aten schemas.""" - torch_schemas = [_TorchSchema(s) for s in _C._jit_get_all_schemas()] - torch_schemas = sorted(torch_schemas, key=lambda x: x.name) - aten_schemas = [s for s in torch_schemas if s.is_aten() and not s.is_backward()] - return aten_schemas - - def _symbolic_argument_count(func): params = [] signature = inspect.signature(func) @@ -72,7 +66,14 @@ def _symbolic_argument_count(func): return params -def _all_symbolics_schemas() -> Dict[str, _TorchSchema]: +def all_forward_schemas() -> Dict[str, _TorchSchema]: + """Returns schemas for all TorchScript forward ops.""" + torch_schemas = [_TorchSchema(s) for s in _C._jit_get_all_schemas()] + return {schema.name: schema for schema in torch_schemas if not schema.is_backward()} + + +def all_symbolics_schemas() -> Dict[str, _TorchSchema]: + """Returns schemas for all onnx supported ops.""" symbolics_schemas = {} for name in registration.registry.all_functions(): @@ -94,19 +95,3 @@ def _all_symbolics_schemas() -> Dict[str, _TorchSchema]: symbolics_schemas[name] = symbolics_schema return symbolics_schemas - - -def onnx_supported_ops(): - aten_schemas = _all_aten_forward_schemas() - symbolic_schemas = _all_symbolics_schemas() - torch_schemas = set(symbolic_schemas.values()) - supported_ops = [] - onnx_supported = [] - for schema in aten_schemas: - if schema in torch_schemas: - opname = schema.name - opsets = symbolic_schemas[opname].opsets - if schema not in supported_ops: - supported_ops.append(symbolic_schemas[opname]) - onnx_supported.append((opname, " ".join(str(o) for o in opsets))) - return sorted(onnx_supported, key=lambda x: x[0])