[ONNX] Supporting different opset versions for torchlib registry (#149901)

- Allows opset_version to determine which ONNX decomposition to choose
- Adds a cleanup function to modify the registry after it is built

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149901
Approved by: https://github.com/justinchuby, https://github.com/titaiwangms

Parent: 97a5e5c6b3
Commit: 1a56609e75
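In practice this means opset_version now chooses between implementations of the same op. A minimal sketch of the user-facing behavior, mirroring the new test in the first hunk below:

import torch

class GeluModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.gelu(x, approximate="tanh")

x = torch.randn(1, 3, 4, 4)

# At the default opset (18) ONNX has no native Gelu, so the exporter emits the
# tanh-based decomposition and "Tanh" nodes appear in the graph.
prog18 = torch.onnx.export(GeluModel(), x, dynamo=True)

# At opset 20 the native ONNX Gelu is available, so the opset-20 torchlib
# implementation is chosen and a "Gelu" node appears instead.
prog20 = torch.onnx.export(GeluModel(), x, opset_version=20, dynamo=True)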
@@ -246,6 +246,31 @@ class TestExportAPIDynamo(common_utils.TestCase):
             )
         )
 
+    def test_upgraded_torchlib_impl(self):
+        class GeluModel(torch.nn.Module):
+            def forward(self, input):
+                # Use GELU activation function
+                return torch.nn.functional.gelu(input, approximate="tanh")
+
+        input = torch.randn(1, 3, 4, 4)
+        onnx_program_op18 = torch.onnx.export(
+            GeluModel(),
+            input,
+            dynamo=True,
+        )
+        all_nodes_op18 = [n.op_type for n in onnx_program_op18.model.graph]
+        self.assertIn("Tanh", all_nodes_op18)
+        self.assertNotIn("Gelu", all_nodes_op18)
+
+        onnx_program_op20 = torch.onnx.export(
+            GeluModel(),
+            input,
+            opset_version=20,
+            dynamo=True,
+        )
+        all_nodes_op20 = [n.op_type for n in onnx_program_op20.model.graph]
+        self.assertIn("Gelu", all_nodes_op20)
+
     def test_refine_dynamic_shapes_with_onnx_export(self):
         # NOTE: From test/export/test_export.py
 
@@ -52,6 +52,7 @@ FLOAT_TYPES = (
     torch.float64,
 )
 
+
 TEST_OPSET_VERSION = 18
 IS_MACOS = sys.platform.startswith("darwin")
 IS_WINDOWS = os.name == "nt"
@@ -487,6 +488,7 @@ def dtype_op_schema_compatible(dtype: torch.dtype, schema: onnx.defs.OpSchema) -> bool:
 def graph_executor(
     test_name: str,
     outputs: Sequence[Any],
+    opset_version: int = TEST_OPSET_VERSION,
 ) -> Callable[[Callable[..., Any], tuple[Any], dict[str, Any]], None]:
     """Eagerly executes a function."""
 
@@ -500,10 +502,10 @@ def graph_executor(
             (),
             (),
             nodes=(),
-            opset_imports={"": 18, "pkg.torch.onnx": 1},
+            opset_imports={"": opset_version, "pkg.torch.onnx": 1},
             name="main_graph",
         )
-        opset = onnxscript.opset18
+        opset = onnxscript.values.Opset("", opset_version)
         tracer = _building.OpRecorder(opset, {})
         ort_inputs = {}
         onnxscript_args: list[Any] = []
@@ -590,7 +592,7 @@ def graph_executor(
             proto = onnxscript_function.to_function_proto()
             ir_function = ir.serde.deserialize_function(proto)
             onnx_model.functions[identifier] = ir_function
-        _ir_passes.add_torchlib_common_imports(onnx_model)
+        _ir_passes.add_torchlib_common_imports(onnx_model, opset_version=opset_version)
         _ir_passes.add_opset_imports(onnx_model)
         # Make sure the model is valid
         model_proto = ir.to_proto(onnx_model)
@@ -46,7 +46,7 @@ import numpy as np
 import ops_test_common
 
 import torch
-from torch.onnx._internal.exporter._torchlib.ops import core as core_ops
+from torch.onnx._internal.exporter._torchlib.ops import core as core_ops, nn as nn_ops
 from torch.testing._internal import common_methods_invocations
 from torch.testing._internal.opinfo import definitions as opinfo_definitions
 
@@ -78,6 +78,12 @@ class TorchLibOpInfo:
     compare_shape_only_for_output: tuple[int, ...] = ()
     # Whether the function is designed for complex inputs
     complex: bool = False
+    # The ONNX opset version in which the function was introduced.
+    # It specifies the minimum ONNX opset version required to use the function.
+    # It ensures that the function is only used when the target ONNX opset version
+    # is compatible. For example, if `opset_introduced=20`, the function will only
+    # be used when exporting to ONNX models targeting opset version 20 or higher.
+    opset_introduced: int = 18
     # The acceptable tolerance of the inference result difference between PyTorch and ORT.
     # Format: {dtype: (rtol, atol)}.
     # For example: {torch.float16: (1e-3, 1e-3)}
@@ -447,8 +453,10 @@ TESTED_TORCHLIB_OPS: tuple[TorchLibOpInfo, ...] = (
     TorchLibOpInfo("abs", core_ops.aten_abs_complex, complex=True),
     TorchLibOpInfo("add", core_ops.aten_add, tolerance={torch.float16: (1e-3, 1e-3)}),
     TorchLibOpInfo("add", core_ops.aten_add_complex, complex=True),
+    TorchLibOpInfo("gelu_op20", nn_ops.aten_gelu_opset20, opset_introduced=20),
 )
 
+
 ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim", "all_dims"))
 ops_test_common.duplicate_opinfo(OPS_DB, "any", ("any_dim", "any_dims"))
 ops_test_common.duplicate_opinfo(
@@ -500,6 +508,7 @@ ops_test_common.duplicate_opinfo(
         "nn.functional.replication_pad3d",
     ),
 )
+ops_test_common.duplicate_opinfo(OPS_DB, "nn.functional.gelu", ("gelu_op20",))
 ops_test_common.duplicate_opinfo(
     OPS_DB,
     "nn.functional.scaled_dot_product_attention",
@@ -220,7 +220,9 @@ def run_test_output_match(
 
         test_name = test_suite.id()
         function_output, model_proto = function_executor(
-            test_name, reference_torch_outputs
+            test_name,
+            reference_torch_outputs,
+            opset_version=torchlib_op_info.opset_introduced,
         )(onnx_function, input_onnx, kwargs_onnx)
         # Finally we re-flatten everything
         # TODO: add pytree structure comparison.
@@ -50,7 +50,7 @@ def export_compat(
     verbose: bool | None = None,
     input_names: Sequence[str] | None = None,
     output_names: Sequence[str] | None = None,
-    opset_version: int | None = None,
+    opset_version: int | None = _constants.TORCHLIB_OPSET,
     custom_translation_table: dict[Callable, Callable | Sequence[Callable]]
     | None = None,
     dynamic_axes: Mapping[str, Mapping[int, str]]
@@ -105,8 +105,7 @@ def export_compat(
         dynamic_shapes_with_export_dim, need_axis_mapping = (
             _dynamic_shapes.convert_str_to_export_dim(dynamic_shapes)
         )
-        registry = _registration.ONNXRegistry.from_torchlib()
-
+        registry = _registration.ONNXRegistry().from_torchlib(opset_version=opset_version)
         if custom_translation_table is not None:
             for torch_op, onnx_ops in custom_translation_table.items():
                 # TODO(justinchuby): Support complex inputs with annotations
@@ -90,7 +90,9 @@ def rename_axis(model: ir.Model, rename_mapping: dict[str, str]) -> None:
             value.shape = ir.Shape(new_shape)
 
 
-def add_torchlib_common_imports(model: ir.Model) -> None:
+def add_torchlib_common_imports(
+    model: ir.Model, opset_version: int = _constants.TORCHLIB_OPSET
+) -> None:
     """Hack to add torchlib common imports to the model."""
 
     try:
@@ -99,9 +101,11 @@ def add_torchlib_common_imports(model: ir.Model) -> None:
 
         model.opset_imports["pkg.onnxscript.torch_lib.common"] = 1
         rank_func = ir.serde.deserialize_function(common_ops.Rank.to_function_proto())
+        rank_func.opset_imports[""] = opset_version
         is_scalar_func = ir.serde.deserialize_function(
            common_ops.IsScalar.to_function_proto()
         )
+        is_scalar_func.opset_imports[""] = opset_version
         model.functions[rank_func.identifier()] = rank_func
         model.functions[is_scalar_func.identifier()] = is_scalar_func
     except Exception:
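The two added opset_imports assignments matter because a deserialized function carries its own opset imports, separate from the model's; when the model targets a newer opset, the shared Rank/IsScalar helpers must be retargeted too. A minimal sketch of the invariant being maintained (illustrative helper name, not exporter API):

# Hypothetical helper: keep every function's default-domain ("") opset import
# in sync with the model's target opset, as add_torchlib_common_imports now
# does for the Rank and IsScalar helpers. `model` is assumed to be an ONNX IR
# model exposing the `opset_imports` and `functions` mappings seen in the diff.
def align_function_opsets(model, opset_version: int) -> None:
    model.opset_imports[""] = opset_version
    for func in model.functions.values():
        func.opset_imports[""] = opset_version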
@@ -42,6 +42,9 @@ class OnnxDecompMeta:
         signature: The ONNX signature of the function. When None, the signature is inferred.
         is_custom: Whether the function is a custom function.
         is_complex: Whether the function is a function that handles complex valued inputs.
+        opset_introduced:
+            The ONNX opset version in which the function was introduced.
+            It specifies the minimum ONNX opset version required to use the function.
         device: The device the function is registered to. If None, it is registered to all devices.
         skip_signature_inference: Whether to skip signature inference for the function.
     """
@@ -51,6 +54,7 @@ class OnnxDecompMeta:
     signature: _schemas.OpSignature | None
     is_custom: bool = False
     is_complex: bool = False
+    opset_introduced: int = 18
     device: Literal["cuda", "cpu"] | str | None = None  # noqa: PYI051
     skip_signature_inference: bool = False
 
@@ -150,13 +154,14 @@ class ONNXRegistry:
         return self._opset_version
 
     @classmethod
-    def from_torchlib(cls) -> ONNXRegistry:
+    def from_torchlib(cls, opset_version=_constants.TORCHLIB_OPSET) -> ONNXRegistry:
         """Populates the registry with ATen functions from torchlib.
 
         Args:
             torchlib_registry: The torchlib registry to use for populating the registry.
         """
         registry = cls()
+        registry._opset_version = opset_version
         for meta in _torchlib_registry.get_torchlib_ops():
             registry._register(meta.fx_target, meta)
 
@@ -185,6 +190,7 @@ class ONNXRegistry:
                 logger.exception("Failed to register '%s'. Skipped", qualified_name)
                 continue
 
+        registry._cleanup_registry_based_on_opset_version()
        return registry
 
     def _register(
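A usage sketch for the updated classmethod (internal module path as it appears in the diff; opset_version defaults to _constants.TORCHLIB_OPSET, i.e. 18):

from torch.onnx._internal.exporter import _registration

# Build a registry targeting opset 20: implementations whose opset_introduced
# exceeds 20 are dropped, and for each op only the newest implementation valid
# at opset 20 survives the cleanup pass defined in the next hunk.
registry = _registration.ONNXRegistry.from_torchlib(opset_version=20)
assert registry.opset_version == 20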
@@ -274,5 +280,24 @@ class ONNXRegistry:
         """
         return bool(self.get_decomps(target))
 
+    def _cleanup_registry_based_on_opset_version(self) -> None:
+        """Pick the implementation with the highest opset version valid until the current opset version."""
+        cleaned_functions = {}
+        for target_or_name, decomps in self.functions.items():
+            # Filter decompositions to only include those with opset_introduced <= opset_version
+            decomps = [d for d in decomps if d.opset_introduced <= self.opset_version]
+
+            # Keep only the decomposition with the highest opset_introduced
+            if decomps:
+                # Find the maximum opset_introduced
+                max_opset = max(d.opset_introduced for d in decomps)
+
+                # Keep all decompositions with the maximum opset_introduced
+                cleaned_functions[target_or_name] = [
+                    d for d in decomps if d.opset_introduced == max_opset
+                ]
+
+        self.functions = cleaned_functions
+
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}(functions={self.functions})"
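To make the cleanup rule concrete, a self-contained sketch of the same filter-then-max selection (hypothetical stand-in type, not the exporter's):

from dataclasses import dataclass

@dataclass
class Decomp:
    name: str
    opset_introduced: int

def pick(decomps: list[Decomp], target_opset: int) -> list[Decomp]:
    # Drop implementations newer than the target opset, then keep only the
    # newest of what remains: the rule the cleanup pass applies per target.
    valid = [d for d in decomps if d.opset_introduced <= target_opset]
    if not valid:
        return []
    newest = max(d.opset_introduced for d in valid)
    return [d for d in valid if d.opset_introduced == newest]

gelu = [Decomp("gelu_tanh_decomposition", 18), Decomp("native_gelu", 20)]
assert [d.name for d in pick(gelu, 18)] == ["gelu_tanh_decomposition"]
assert [d.name for d in pick(gelu, 20)] == ["native_gelu"]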
@@ -30,6 +30,7 @@ def onnx_impl(
     *,
     trace_only: bool = False,
     complex: bool = False,
+    opset_introduced: int = 18,
     no_compile: bool = False,
     private: bool = False,
 ) -> Callable[[_T], _T]:
@@ -74,6 +75,7 @@ def onnx_impl(
                 fx_target=t,
                 signature=None,
                 is_complex=complex,
+                opset_introduced=opset_introduced,
                 skip_signature_inference=no_compile,
             )
         )
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
 
-__all__ = ["core", "hop", "symbolic"]
+__all__ = ["core", "hop", "nn", "symbolic"]
 
-from torch.onnx._internal.exporter._torchlib.ops import core, hop, symbolic
+from torch.onnx._internal.exporter._torchlib.ops import core, hop, nn, symbolic
torch/onnx/_internal/exporter/_torchlib/ops/nn.py (new file, 26 lines)
@@ -0,0 +1,26 @@
+"""torch.ops.aten operators under the `nn` module."""
+# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value,type-var,operator,no-untyped-def,index"
+# ruff: noqa: TCH001,TCH002
+# flake8: noqa
+
+from __future__ import annotations
+
+import math
+
+from onnxscript.onnx_opset import opset20 as op20
+
+import torch
+from torch.onnx._internal.exporter._torchlib._tensor_typing import TReal
+from torch.onnx._internal.exporter._torchlib._torchlib_registry import onnx_impl
+
+
+aten = torch.ops.aten
+
+
+@onnx_impl(aten.gelu.default, trace_only=True, opset_introduced=20)
+def aten_gelu_opset20(
+    self: TReal,
+    approximate: str = "none",
+) -> TReal:
+    """gelu(Tensor self, *, bool approximate=False) -> Tensor"""
+    return op20.Gelu(self, approximate=approximate)
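For context: ONNX opset 20 introduced a native Gelu operator with an approximate attribute, which is why this implementation is gated at opset_introduced=20; at opset 18 the exporter instead emits the tanh decomposition. A quick numeric check of that decomposition in plain PyTorch:

import math
import torch

x = torch.randn(1, 3, 4, 4)
expected = torch.nn.functional.gelu(x, approximate="tanh")
# The tanh approximation the pre-opset-20 path lowers to:
#   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))
manual = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))
torch.testing.assert_close(manual, expected)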