diff --git a/test/onnx/exporter/test_api.py b/test/onnx/exporter/test_api.py
index a4dc1c97772..3ebf00eccec 100644
--- a/test/onnx/exporter/test_api.py
+++ b/test/onnx/exporter/test_api.py
@@ -246,6 +246,31 @@ class TestExportAPIDynamo(common_utils.TestCase):
             )
         )
 
+    def test_upgraded_torchlib_impl(self):
+        class GeluModel(torch.nn.Module):
+            def forward(self, input):
+                # Use GELU activation function
+                return torch.nn.functional.gelu(input, approximate="tanh")
+
+        input = torch.randn(1, 3, 4, 4)
+        onnx_program_op18 = torch.onnx.export(
+            GeluModel(),
+            input,
+            dynamo=True,
+        )
+        all_nodes_op18 = [n.op_type for n in onnx_program_op18.model.graph]
+        self.assertIn("Tanh", all_nodes_op18)
+        self.assertNotIn("Gelu", all_nodes_op18)
+
+        onnx_program_op20 = torch.onnx.export(
+            GeluModel(),
+            input,
+            opset_version=20,
+            dynamo=True,
+        )
+        all_nodes_op20 = [n.op_type for n in onnx_program_op20.model.graph]
+        self.assertIn("Gelu", all_nodes_op20)
+
     def test_refine_dynamic_shapes_with_onnx_export(self):
         # NOTE: From test/export/test_export.py
diff --git a/test/onnx/torchlib/ops_test_common.py b/test/onnx/torchlib/ops_test_common.py
index 73c00de388f..884b66d4e02 100644
--- a/test/onnx/torchlib/ops_test_common.py
+++ b/test/onnx/torchlib/ops_test_common.py
@@ -52,6 +52,7 @@ FLOAT_TYPES = (
     torch.float64,
 )
 
+TEST_OPSET_VERSION = 18
 IS_MACOS = sys.platform.startswith("darwin")
 IS_WINDOWS = os.name == "nt"
@@ -487,6 +488,7 @@ def dtype_op_schema_compatible(dtype: torch.dtype, schema: onnx.defs.OpSchema) -> bool:
 def graph_executor(
     test_name: str,
     outputs: Sequence[Any],
+    opset_version: int = TEST_OPSET_VERSION,
 ) -> Callable[[Callable[..., Any], tuple[Any], dict[str, Any]], None]:
     """Eagerly executes a function."""
 
@@ -500,10 +502,10 @@ def graph_executor(
         (),
         (),
         nodes=(),
-        opset_imports={"": 18, "pkg.torch.onnx": 1},
+        opset_imports={"": opset_version, "pkg.torch.onnx": 1},
         name="main_graph",
     )
-    opset = onnxscript.opset18
+    opset = onnxscript.values.Opset("", opset_version)
     tracer = _building.OpRecorder(opset, {})
     ort_inputs = {}
     onnxscript_args: list[Any] = []
@@ -590,7 +592,7 @@ def graph_executor(
             proto = onnxscript_function.to_function_proto()
             ir_function = ir.serde.deserialize_function(proto)
             onnx_model.functions[identifier] = ir_function
-        _ir_passes.add_torchlib_common_imports(onnx_model)
+        _ir_passes.add_torchlib_common_imports(onnx_model, opset_version=opset_version)
         _ir_passes.add_opset_imports(onnx_model)
         # Make sure the model is valid
         model_proto = ir.to_proto(onnx_model)
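The new `test_upgraded_torchlib_impl` pins down the user-visible behavior. As a minimal standalone sketch (assuming a build with this patch applied), the same check can be run outside the test suite:

```python
import torch


class GeluModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.gelu(x, approximate="tanh")


x = torch.randn(1, 3, 4, 4)

# The default export target is opset 18, where the standard ONNX Gelu does
# not exist yet, so the exporter emits the tanh-based decomposition.
prog18 = torch.onnx.export(GeluModel(), (x,), dynamo=True)
assert "Tanh" in [n.op_type for n in prog18.model.graph]

# Targeting opset 20 picks up the new Gelu-20 torchlib implementation.
prog20 = torch.onnx.export(GeluModel(), (x,), opset_version=20, dynamo=True)
assert "Gelu" in [n.op_type for n in prog20.model.graph]
```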
diff --git a/test/onnx/torchlib/ops_test_data.py b/test/onnx/torchlib/ops_test_data.py
index b255f07640b..a69d7a4bec1 100644
--- a/test/onnx/torchlib/ops_test_data.py
+++ b/test/onnx/torchlib/ops_test_data.py
@@ -46,7 +46,7 @@ import numpy as np
 import ops_test_common
 import torch
 
-from torch.onnx._internal.exporter._torchlib.ops import core as core_ops
+from torch.onnx._internal.exporter._torchlib.ops import core as core_ops, nn as nn_ops
 from torch.testing._internal import common_methods_invocations
 from torch.testing._internal.opinfo import definitions as opinfo_definitions
@@ -78,6 +78,12 @@ class TorchLibOpInfo:
     compare_shape_only_for_output: tuple[int, ...] = ()
     # Whether the function is designed for complex inputs
    complex: bool = False
+    # The ONNX opset version in which the function was introduced.
+    # It specifies the minimum ONNX opset version required to use the function,
+    # ensuring the function is only used when the target ONNX opset version
+    # is compatible. For example, if `opset_introduced=20`, the function will only
+    # be used when exporting to ONNX models targeting opset version 20 or higher.
+    opset_introduced: int = 18
     # The acceptable tolerance of the inference result difference between PyTorch and ORT.
     # Format: {dtype: (rtol, atol)}.
     # For example: {torch.float16: (1e-3, 1e-3)}
@@ -447,8 +453,10 @@ TESTED_TORCHLIB_OPS: tuple[TorchLibOpInfo, ...] = (
     TorchLibOpInfo("abs", core_ops.aten_abs_complex, complex=True),
     TorchLibOpInfo("add", core_ops.aten_add, tolerance={torch.float16: (1e-3, 1e-3)}),
     TorchLibOpInfo("add", core_ops.aten_add_complex, complex=True),
+    TorchLibOpInfo("gelu_op20", nn_ops.aten_gelu_opset20, opset_introduced=20),
 )
 
+
 ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim", "all_dims"))
 ops_test_common.duplicate_opinfo(OPS_DB, "any", ("any_dim", "any_dims"))
 ops_test_common.duplicate_opinfo(
@@ -500,6 +508,7 @@ ops_test_common.duplicate_opinfo(
         "nn.functional.replication_pad3d",
     ),
 )
+ops_test_common.duplicate_opinfo(OPS_DB, "nn.functional.gelu", ("gelu_op20",))
 ops_test_common.duplicate_opinfo(
     OPS_DB,
     "nn.functional.scaled_dot_product_attention",
diff --git a/test/onnx/torchlib/test_ops.py b/test/onnx/torchlib/test_ops.py
index 74cbeeca313..a7a52698cd2 100644
--- a/test/onnx/torchlib/test_ops.py
+++ b/test/onnx/torchlib/test_ops.py
@@ -220,7 +220,9 @@ def run_test_output_match(
         test_name = test_suite.id()
         function_output, model_proto = function_executor(
-            test_name, reference_torch_outputs
+            test_name,
+            reference_torch_outputs,
+            opset_version=torchlib_op_info.opset_introduced,
         )(onnx_function, input_onnx, kwargs_onnx)
         # Finally we re-flatten everything
         # TODO: add pytree structure comparison.
diff --git a/torch/onnx/_internal/exporter/_compat.py b/torch/onnx/_internal/exporter/_compat.py
index a38203d2314..b570b20bd02 100644
--- a/torch/onnx/_internal/exporter/_compat.py
+++ b/torch/onnx/_internal/exporter/_compat.py
@@ -50,7 +50,7 @@ def export_compat(
     verbose: bool | None = None,
     input_names: Sequence[str] | None = None,
     output_names: Sequence[str] | None = None,
-    opset_version: int | None = None,
+    opset_version: int | None = _constants.TORCHLIB_OPSET,
     custom_translation_table: dict[Callable, Callable | Sequence[Callable]]
     | None = None,
     dynamic_axes: Mapping[str, Mapping[int, str]]
@@ -105,8 +105,7 @@ def export_compat(
     dynamic_shapes_with_export_dim, need_axis_mapping = (
         _dynamic_shapes.convert_str_to_export_dim(dynamic_shapes)
     )
-
-    registry = _registration.ONNXRegistry.from_torchlib()
+    registry = _registration.ONNXRegistry.from_torchlib(opset_version=opset_version)
     if custom_translation_table is not None:
         for torch_op, onnx_ops in custom_translation_table.items():
             # TODO(justinchuby): Support complex inputs with annotations
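With `export_compat` now defaulting `opset_version` to `_constants.TORCHLIB_OPSET` (18) and forwarding it into registry construction, opset selection happens once, up front. A hedged sketch of the resulting behavior, assuming a build with this patch (these are internal APIs, subject to change):

```python
import torch
from torch.onnx._internal.exporter import _registration

# Each registry keeps, per ATen target, only the decompositions whose
# opset_introduced is the highest one not exceeding the target opset.
reg18 = _registration.ONNXRegistry.from_torchlib(opset_version=18)
reg20 = _registration.ONNXRegistry.from_torchlib(opset_version=20)

# Presumably: the generic (opset_introduced=18) implementation is selected
# at opset 18, while aten_gelu_opset20 replaces it at opset 20.
print(reg18.get_decomps(torch.ops.aten.gelu.default))
print(reg20.get_decomps(torch.ops.aten.gelu.default))
```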
diff --git a/torch/onnx/_internal/exporter/_ir_passes.py b/torch/onnx/_internal/exporter/_ir_passes.py
index 804e93acbd6..8a715e24559 100644
--- a/torch/onnx/_internal/exporter/_ir_passes.py
+++ b/torch/onnx/_internal/exporter/_ir_passes.py
@@ -90,7 +90,9 @@ def rename_axis(model: ir.Model, rename_mapping: dict[str, str]) -> None:
         value.shape = ir.Shape(new_shape)
 
 
-def add_torchlib_common_imports(model: ir.Model) -> None:
+def add_torchlib_common_imports(
+    model: ir.Model, opset_version: int = _constants.TORCHLIB_OPSET
+) -> None:
     """Hack to add torchlib common imports to the model."""
 
     try:
@@ -99,9 +101,11 @@ def add_torchlib_common_imports(model: ir.Model) -> None:
 
         model.opset_imports["pkg.onnxscript.torch_lib.common"] = 1
         rank_func = ir.serde.deserialize_function(common_ops.Rank.to_function_proto())
+        rank_func.opset_imports[""] = opset_version
         is_scalar_func = ir.serde.deserialize_function(
             common_ops.IsScalar.to_function_proto()
         )
+        is_scalar_func.opset_imports[""] = opset_version
         model.functions[rank_func.identifier()] = rank_func
         model.functions[is_scalar_func.identifier()] = is_scalar_func
     except Exception:
diff --git a/torch/onnx/_internal/exporter/_registration.py b/torch/onnx/_internal/exporter/_registration.py
index ac81d2301cc..fefc8022d7e 100644
--- a/torch/onnx/_internal/exporter/_registration.py
+++ b/torch/onnx/_internal/exporter/_registration.py
@@ -42,6 +42,9 @@ class OnnxDecompMeta:
         signature: The ONNX signature of the function. When None, the signature is inferred.
         is_custom: Whether the function is a custom function.
         is_complex: Whether the function is a function that handles complex valued inputs.
+        opset_introduced:
+            The ONNX opset version in which the function was introduced.
+            It specifies the minimum ONNX opset version required to use the function.
         device: The device the function is registered to. If None, it is registered to all devices.
         skip_signature_inference: Whether to skip signature inference for the function.
     """
@@ -51,6 +54,7 @@ class OnnxDecompMeta:
     signature: _schemas.OpSignature | None
     is_custom: bool = False
     is_complex: bool = False
+    opset_introduced: int = 18
     device: Literal["cuda", "cpu"] | str | None = None  # noqa: PYI051
     skip_signature_inference: bool = False
@@ -150,13 +154,14 @@ class ONNXRegistry:
         return self._opset_version
 
     @classmethod
-    def from_torchlib(cls) -> ONNXRegistry:
+    def from_torchlib(cls, opset_version: int = _constants.TORCHLIB_OPSET) -> ONNXRegistry:
         """Populates the registry with ATen functions from torchlib.
 
         Args:
-            torchlib_registry: The torchlib registry to use for populating the registry.
+            opset_version: The ONNX opset version the registry should target.
         """
         registry = cls()
+        registry._opset_version = opset_version
         for meta in _torchlib_registry.get_torchlib_ops():
             registry._register(meta.fx_target, meta)
@@ -185,6 +190,7 @@ class ONNXRegistry:
                 logger.exception("Failed to register '%s'. Skipped", qualified_name)
                 continue
 
+        registry._cleanup_registry_based_on_opset_version()
         return registry
 
     def _register(
@@ -274,5 +280,24 @@ class ONNXRegistry:
         """
         return bool(self.get_decomps(target))
 
+    def _cleanup_registry_based_on_opset_version(self) -> None:
+        """Pick the implementations with the highest opset_introduced that does not exceed the current opset version."""
+        cleaned_functions = {}
+        for target_or_name, decomps in self.functions.items():
+            # Filter decompositions to only those with opset_introduced <= opset_version
+            decomps = [d for d in decomps if d.opset_introduced <= self.opset_version]
+
+            # Keep only the decompositions with the highest opset_introduced
+            if decomps:
+                # Find the maximum opset_introduced
+                max_opset = max(d.opset_introduced for d in decomps)
+
+                # Keep all decompositions with the maximum opset_introduced
+                cleaned_functions[target_or_name] = [
+                    d for d in decomps if d.opset_introduced == max_opset
+                ]
+
+        self.functions = cleaned_functions
+
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}(functions={self.functions})"
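The selection rule in `_cleanup_registry_based_on_opset_version` reduces to "highest `opset_introduced` not exceeding the target opset". A self-contained sketch with hypothetical data (the `Decomp` class is an illustrative stand-in for `OnnxDecompMeta`):

```python
from dataclasses import dataclass


@dataclass
class Decomp:  # hypothetical stand-in for OnnxDecompMeta
    name: str
    opset_introduced: int


def select(decomps: list[Decomp], target_opset: int) -> list[Decomp]:
    # Drop implementations that require a newer opset than the target...
    eligible = [d for d in decomps if d.opset_introduced <= target_opset]
    if not eligible:
        return []
    # ...then keep only those introduced at the highest remaining opset.
    best = max(d.opset_introduced for d in eligible)
    return [d for d in eligible if d.opset_introduced == best]


candidates = [Decomp("gelu_generic", 18), Decomp("gelu_op20", 20)]
assert [d.name for d in select(candidates, 18)] == ["gelu_generic"]
assert [d.name for d in select(candidates, 20)] == ["gelu_op20"]
```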
Skipped", qualified_name) continue + registry._cleanup_registry_based_on_opset_version() return registry def _register( @@ -274,5 +280,24 @@ class ONNXRegistry: """ return bool(self.get_decomps(target)) + def _cleanup_registry_based_on_opset_version(self) -> None: + """Pick the implementation with the highest opset version valid until the current opset version.""" + cleaned_functions = {} + for target_or_name, decomps in self.functions.items(): + # Filter decompositions to only include those with opset_introduced <= opset_version + decomps = [d for d in decomps if d.opset_introduced <= self.opset_version] + + # Keep only the decomposition with the highest opset_introduced + if decomps: + # Find the maximum opset_introduced + max_opset = max(d.opset_introduced for d in decomps) + + # Keep all decompositions with the maximum opset_introduced + cleaned_functions[target_or_name] = [ + d for d in decomps if d.opset_introduced == max_opset + ] + + self.functions = cleaned_functions + def __repr__(self) -> str: return f"{self.__class__.__name__}(functions={self.functions})" diff --git a/torch/onnx/_internal/exporter/_torchlib/_torchlib_registry.py b/torch/onnx/_internal/exporter/_torchlib/_torchlib_registry.py index e71bdeb0c68..039eeb3e2fc 100644 --- a/torch/onnx/_internal/exporter/_torchlib/_torchlib_registry.py +++ b/torch/onnx/_internal/exporter/_torchlib/_torchlib_registry.py @@ -30,6 +30,7 @@ def onnx_impl( *, trace_only: bool = False, complex: bool = False, + opset_introduced: int = 18, no_compile: bool = False, private: bool = False, ) -> Callable[[_T], _T]: @@ -74,6 +75,7 @@ def onnx_impl( fx_target=t, signature=None, is_complex=complex, + opset_introduced=opset_introduced, skip_signature_inference=no_compile, ) ) diff --git a/torch/onnx/_internal/exporter/_torchlib/ops/__init__.py b/torch/onnx/_internal/exporter/_torchlib/ops/__init__.py index d07768f252b..bff8860fcb1 100644 --- a/torch/onnx/_internal/exporter/_torchlib/ops/__init__.py +++ b/torch/onnx/_internal/exporter/_torchlib/ops/__init__.py @@ -1,6 +1,6 @@ from __future__ import annotations -__all__ = ["core", "hop", "symbolic"] +__all__ = ["core", "hop", "nn", "symbolic"] -from torch.onnx._internal.exporter._torchlib.ops import core, hop, symbolic +from torch.onnx._internal.exporter._torchlib.ops import core, hop, nn, symbolic diff --git a/torch/onnx/_internal/exporter/_torchlib/ops/nn.py b/torch/onnx/_internal/exporter/_torchlib/ops/nn.py new file mode 100644 index 00000000000..4ca21662d69 --- /dev/null +++ b/torch/onnx/_internal/exporter/_torchlib/ops/nn.py @@ -0,0 +1,26 @@ +"""torch.ops.aten operators under the `core` module.""" +# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value,type-var,operator,no-untyped-def,index" +# ruff: noqa: TCH001,TCH002 +# flake8: noqa + +from __future__ import annotations + +import math + +from onnxscript.onnx_opset import opset20 as op20 + +import torch +from torch.onnx._internal.exporter._torchlib._tensor_typing import TReal +from torch.onnx._internal.exporter._torchlib._torchlib_registry import onnx_impl + + +aten = torch.ops.aten + + +@onnx_impl(aten.gelu.default, trace_only=True, opset_introduced=20) +def aten_gelu_opset20( + self: TReal, + approximate: str = "none", +) -> TReal: + """gelu(Tensor self, *, bool approximate=False) -> Tensor""" + return op20.Gelu(self, approximate=approximate)