# necessary to surface onnx.ModelProto through ExportOutput:
from __future__ import annotations

import abc

import contextlib
import dataclasses
import io
import logging
import os

import warnings
from collections import defaultdict
from typing import (
    Any,
    Callable,
    Dict,
    Final,
    List,
    Mapping,
    Optional,
    Protocol,
    runtime_checkable,
    Sequence,
    Set,
    Tuple,
    TYPE_CHECKING,
    TypeVar,
    Union,
)

from typing_extensions import Self

import torch
import torch._ops
import torch.utils._pytree as pytree
from torch._subclasses import fake_tensor

from torch.onnx._internal import _beartype, io_adapter
from torch.onnx._internal.diagnostics import infra

from torch.onnx._internal.fx import (
    decomposition_table,
    patcher as patcher,
    registration,
    serialization as fx_serialization,
)

# We can only import onnx from this module in a type-checking context to ensure that
# 'import torch.onnx' continues to work without having 'onnx' installed. We fully
# 'import onnx' inside of dynamo_export (by way of _assert_dependencies).
if TYPE_CHECKING:
    import onnx
    import onnxscript  # type: ignore[import]
    from onnxscript.function_libs.torch_lib import (  # type: ignore[import]
        registration as torchlib_registry,
    )

    from torch.onnx._internal.fx import diagnostics
else:
    try:
        # beartype needs this import due to runtime type checking.
        # This cannot be normally imported at top level due to
        # https://github.com/pytorch/pytorch/issues/103764
        from torch.onnx._internal.fx import diagnostics
    except ImportError:
        # The error will be handled elsewhere when the exporter is used.
        pass

_DEFAULT_OPSET_VERSION: Final[int] = 18
"""The default ONNX opset version the exporter will use if one is not specified explicitly
through :class:`ExportOptions`. This should NEVER be accessed outside of this module! Users
should reference :attr:`ExportOptions.opset_version`."""

_PYTORCH_GITHUB_ISSUES_URL = "https://github.com/pytorch/pytorch/issues"
"""The URL to the PyTorch GitHub issues page."""

_DEFAULT_FAILED_EXPORT_SARIF_LOG_PATH = "report_dynamo_export.sarif"
"""The default path to write the SARIF log to if the export fails."""

log = logging.getLogger(__name__)


DiagnosticOptions = infra.DiagnosticOptions


@dataclasses.dataclass
class ONNXFakeContext:
    """A dataclass used to store context for model export using FakeTensor.

    This dataclass stores the FakeTensorMode instance used to convert
    real tensors and model parameters into fake tensors. This :attr:`ONNXFakeContext.fake_mode` is
    reused internally during tracing of a :class:`torch.nn.Module` into an FX :class:`GraphModule`.
    """

    fake_mode: fake_tensor.FakeTensorMode
    """The fake tensor mode used for tracing a model with fake tensors and parameters."""

    state_dict_paths: Optional[Tuple[Union[str, io.BytesIO]]] = None
    """List of paths of files that contain the model :meth:`state_dict`"""


class OnnxRegistry:
    """Registry for ONNX functions.

    The registry maintains a mapping from qualified names to symbolic functions under a
    fixed opset version. It supports registering custom onnx-script functions and allows
    the dispatcher to dispatch calls to the appropriate function.

    """

    def __init__(self) -> None:
        """Initializes the registry"""

        # NOTE: _registry maps an OpName to a list of ONNXFunctions. It is important
        # not to directly modify this variable. Instead, access to it should be done through
        # the public methods: register_op, get_op_functions, and is_registered_op.
        self._registry: Dict[
            registration.OpName, List[registration.ONNXFunction]
        ] = defaultdict(list)
        # FIXME: Avoid importing onnxscript into torch
        from onnxscript.function_libs.torch_lib import (  # type: ignore[import]  # noqa: F401
            ops,  # TODO(titaiwang): get rid of this import
            registration,
        )

        # opset_version is unused for now, since torchlib only supports opset18.
        # TODO: get opset version from torchlib
        self._opset_version = _DEFAULT_OPSET_VERSION
        warnings.warn(
            f"torch.onnx.dynamo_export only implements opset version {self._opset_version} for now. If you need to use a "
            "different opset version, please register custom ops with register_op."
        )

        # Initialize registry from torchlib
        self._initiate_registry_from_torchlib(registration.default_registry)

    @property
    def opset_version(self) -> int:
        """The ONNX opset version the exporter should target. Defaults to the latest
        supported ONNX opset version: 18. The default version will increment over time as
        ONNX continues to evolve."""

        return self._opset_version

    # TODO(titaiwang): subject to change if multiple opset versions are supported in torchlib
    def _initiate_registry_from_torchlib(
        self, torchlib_registry: torchlib_registry.Registry
    ):
        """Populates the registry with ATen functions from torchlib.

        Args:
            torchlib_registry: The torchlib registry to use for populating the registry.
        """
        for aten_name, aten_overloads_func in torchlib_registry.items():
            internal_name_instance = registration.OpName.from_qualified_name(aten_name)
            for overload_func in aten_overloads_func.overloads:
                symbolic_function = registration.ONNXFunction(
                    onnx_function=overload_func,
                    op_full_name=internal_name_instance.qualified_name(),
                    is_custom=False,
                    is_complex=False,
                )
                self._register(internal_name_instance, symbolic_function)

            for complex_func in aten_overloads_func.complex:
                symbolic_function = registration.ONNXFunction(
                    onnx_function=complex_func,
                    op_full_name=internal_name_instance.qualified_name(),
                    is_custom=False,
                    is_complex=True,
                )
                self._register(internal_name_instance, symbolic_function)

    @_beartype.beartype
    def _register(
        self,
        internal_qualified_name: registration.OpName,
        symbolic_function: registration.ONNXFunction,
    ) -> None:
        """Registers an ONNXFunction to an operator.

        Args:
            internal_qualified_name: The qualified name of the operator to register: OpName.
            symbolic_function: The ONNXFunction to register.
        """
        self._registry[internal_qualified_name].append(symbolic_function)

    @_beartype.beartype
    def register_op(
        self,
        function: Union["onnxscript.OnnxFunction", "onnxscript.TracedOnnxFunction"],
        namespace: str,
        op_name: str,
        overload: Optional[str] = None,
        is_complex: bool = False,
    ) -> None:
        """Registers a custom operator: torch.ops.<namespace>.<op_name>.<overload>.

        Args:
            function: The onnx-script function to register.
            namespace: The namespace of the operator to register.
            op_name: The name of the operator to register.
            overload: The overload of the operator to register. If it's the default overload,
                leave it as None.
            is_complex: Whether the function handles complex-valued inputs.

        Raises:
            ValueError: If the name is not in the form of 'namespace::op'.
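
        Example:

            A minimal sketch of registering a custom decomposition for ``aten::add.Tensor``
            with onnx-script. The ``custom_aten_add`` function and its opset are
            illustrative and not part of this module::

                # xdoctest: +SKIP("requires onnx and onnxscript")
                >>> import torch
                >>> import onnxscript
                >>> from onnxscript import opset18
                >>> custom_opset = onnxscript.values.Opset(domain="custom.aten", version=1)
                >>> @onnxscript.script(custom_opset)
                ... def custom_aten_add(input_x, input_y, alpha: float = 1.0):
                ...     # Scale the second operand before adding, mirroring aten::add's alpha.
                ...     alpha = opset18.CastLike(alpha, input_y)
                ...     input_y = opset18.Mul(input_y, alpha)
                ...     return opset18.Add(input_x, input_y)
                >>> registry = torch.onnx.OnnxRegistry()
                >>> registry.register_op(
                ...     function=custom_aten_add, namespace="aten", op_name="add", overload="Tensor"
                ... )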
        """
        internal_name_instance = registration.OpName.from_name_parts(
            namespace=namespace, op_name=op_name, overload=overload
        )
        symbolic_function = registration.ONNXFunction(
            onnx_function=function,
            op_full_name=internal_name_instance.qualified_name(),
            is_custom=True,
            is_complex=is_complex,
        )
        self._register(internal_name_instance, symbolic_function)

    @_beartype.beartype
    def get_op_functions(
        self, namespace: str, op_name: str, overload: Optional[str] = None
    ) -> Optional[List[registration.ONNXFunction]]:
        """Returns a list of ONNXFunctions for the given op: torch.ops.<namespace>.<op_name>.<overload>.

        The list is ordered by the time of registration; custom operators, which are
        registered later, appear after the built-in ones.

        Args:
            namespace: The namespace of the operator to get.
            op_name: The name of the operator to get.
            overload: The overload of the operator to get. If it's the default overload,
                leave it as None.
        Returns:
            A list of ONNXFunctions corresponding to the given name, or None if
            the name is not in the registry.
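
        Example:

            A sketch, assuming the default torchlib registry has been populated
            (``is_custom`` is an attribute of the registered functions)::

                # xdoctest: +SKIP("requires onnx and onnxscript")
                >>> registry = torch.onnx.OnnxRegistry()
                >>> functions = registry.get_op_functions(namespace="aten", op_name="add")
                >>> functions is not None and all(not f.is_custom for f in functions)
                True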
        """
        internal_name_instance = registration.OpName.from_name_parts(
            namespace=namespace, op_name=op_name, overload=overload
        )
        return self._registry.get(internal_name_instance)

    @_beartype.beartype
    def is_registered_op(
        self, namespace: str, op_name: str, overload: Optional[str] = None
    ) -> bool:
        """Returns whether the given op is registered: torch.ops.<namespace>.<op_name>.<overload>.

        Args:
            namespace: The namespace of the operator to check.
            op_name: The name of the operator to check.
            overload: The overload of the operator to check. If it's the default overload,
                leave it as None.

        Returns:
            True if the given op is registered, otherwise False.
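
        Example:

            A sketch, assuming the default torchlib registry has been populated::

                # xdoctest: +SKIP("requires onnx and onnxscript")
                >>> registry = torch.onnx.OnnxRegistry()
                >>> registry.is_registered_op(namespace="aten", op_name="add")
                True
                >>> registry.is_registered_op(namespace="aten", op_name="nonexistent_op")
                False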
        """
        functions = self.get_op_functions(
            namespace=namespace, op_name=op_name, overload=overload
        )
        return functions is not None

    @_beartype.beartype
    def _all_registered_ops(self) -> Set[str]:
        """Returns the set of all registered function names."""
        return {
            op_name_class.qualified_name() for op_name_class in self._registry.keys()
        }


class ExportOptions:
    """Options to influence the TorchDynamo ONNX exporter.

    Attributes:
        dynamic_shapes: Shape information hint for input/output tensors.
            When ``None``, the exporter determines the most compatible setting.
            When ``True``, all input shapes are considered dynamic.
            When ``False``, all input shapes are considered static.
        op_level_debug: Whether to export the model with op-level debug information.
        diagnostic_options: The diagnostic options for the exporter.
        fake_context: The fake context used for symbolic tracing.
        onnx_registry: The ONNX registry used to register ATen operators to ONNX functions.
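
    Example:

        A minimal sketch of a dynamic-shape export using these options::

            # xdoctest: +SKIP("requires onnx and onnxscript")
            >>> import torch
            >>> export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
            >>> export_output = torch.onnx.dynamo_export(
            ...     torch.nn.Linear(2, 2), torch.randn(4, 2), export_options=export_options
            ... )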
    """

    dynamic_shapes: Optional[bool] = None
    """Shape information hint for input/output tensors.

    - ``None``: the exporter determines the most compatible setting.
    - ``True``: all input shapes are considered dynamic.
    - ``False``: all input shapes are considered static.
    """

    op_level_debug: Optional[bool] = None
    """When True, export the model with op-level debugging by running ops through ONNX Runtime."""

    diagnostic_options: DiagnosticOptions
    """The diagnostic options for the exporter."""

    fake_context: Optional[ONNXFakeContext] = None
    """The fake context used for symbolic tracing."""

    onnx_registry: Optional[OnnxRegistry] = None
    """The ONNX registry used to register ATen operators to ONNX functions."""

    @_beartype.beartype
    def __init__(
        self,
        *,
        dynamic_shapes: Optional[bool] = None,
        op_level_debug: Optional[bool] = None,
        fake_context: Optional[ONNXFakeContext] = None,
        onnx_registry: Optional[OnnxRegistry] = None,
        diagnostic_options: Optional[DiagnosticOptions] = None,
    ):
        self.dynamic_shapes = dynamic_shapes
        self.op_level_debug = op_level_debug
        self.fake_context = fake_context
        self.onnx_registry = onnx_registry
        self.diagnostic_options = diagnostic_options or DiagnosticOptions()


class ResolvedExportOptions(ExportOptions):
    """Consolidates :class:`ExportOptions` with default values.
    All unspecified options from :class:`ExportOptions` are assigned a default value.
    This is an internal class and its API may be changed at any time without notice.
    """

    # Public attributes MUST be redefined below without ``Optional[]`` from ``ExportOptions``
    dynamic_shapes: bool
    op_level_debug: bool
    diagnostic_options: DiagnosticOptions
    fake_context: ONNXFakeContext
    onnx_registry: OnnxRegistry

    # Private only attributes
    decomposition_table: Dict[torch._ops.OpOverload, Callable]
    """A dictionary that maps operators to their decomposition functions."""

    onnxfunction_dispatcher: torch.onnx._internal.fx.onnxfunction_dispatcher.OnnxFunctionDispatcher
    """The ONNX dispatcher used to dispatch ATen operators to ONNX functions."""

    fx_tracer: FXGraphExtractor
    """The FXGraphExtractor instance used to extract the FX graph from the model."""

    diagnostic_context: diagnostics.DiagnosticContext
    """The diagnostics context for the export. Responsible for recording diagnostics,
    logging diagnostics, and generating the SARIF log."""

    @_beartype.beartype
    def __init__(
        self, options: Optional[Union[ExportOptions, "ResolvedExportOptions"]]
    ):
        if options is None:
            options = ExportOptions()
        if isinstance(options, ResolvedExportOptions):
            self.dynamic_shapes = options.dynamic_shapes
            self.op_level_debug = options.op_level_debug
            self.diagnostic_options = options.diagnostic_options
            self.fake_context = options.fake_context
            # private
            self.fx_tracer = options.fx_tracer
            self.onnx_registry = options.onnx_registry
            self.onnxfunction_dispatcher = options.onnxfunction_dispatcher
            self.decomposition_table = options.decomposition_table
            self.diagnostic_context = options.diagnostic_context
        else:
            T = TypeVar("T")

            @_beartype.beartype
            def resolve(value: Optional[T], fallback: Union[T, Callable[[], T]]) -> T:
                if value is not None:
                    return value
                if callable(fallback):
                    return fallback()
                return fallback

            self.dynamic_shapes = resolve(options.dynamic_shapes, False)
            from torch.onnx._internal.fx import (  # TODO: Prevent circular dep
                diagnostics,
                dynamo_graph_extractor,
            )

            self.diagnostic_options = resolve(
                options.diagnostic_options, DiagnosticOptions()
            )

            self.fx_tracer = dynamo_graph_extractor.DynamoExport()

            self.fake_context = resolve(options.fake_context, None)
            self.diagnostic_context = diagnostics.DiagnosticContext(
                "torch.onnx.dynamo_export",
                torch.__version__,
                self.diagnostic_options,
            )

            self.onnx_registry = resolve(options.onnx_registry, OnnxRegistry())
            self.decomposition_table = (
                decomposition_table.create_onnx_friendly_decomposition_table(
                    self.onnx_registry
                )
            )

            # TODO(titaiwang, bowbao): Better way to annotate `onnxscript` types in diagnostics.
            from torch.onnx._internal.fx import onnxfunction_dispatcher

            self.op_level_debug = resolve(options.op_level_debug, False)
            self.onnxfunction_dispatcher = (
                onnxfunction_dispatcher.OnnxFunctionDispatcher(
                    self.onnx_registry,
                    self.diagnostic_context,
                )
            )

            for key in dir(options):
                if not key.startswith("_"):  # skip private attributes
                    assert hasattr(self, key), f"Unresolved option '{key}'"


@contextlib.contextmanager
def enable_fake_mode():
    """Enable fake mode for the duration of the context.

    Internally it instantiates a :class:`torch._subclasses.fake_tensor.FakeTensorMode` context manager
    that converts user input and model parameters into :class:`torch._subclasses.fake_tensor.FakeTensor`.

    A :class:`torch._subclasses.fake_tensor.FakeTensor`
    is a :class:`torch.Tensor` with the ability to run PyTorch code without having to
    actually do computation through tensors allocated on a ``meta`` device. Because
    there is no actual data being allocated on the device, this API allows for
    exporting large models without the actual memory footprint needed for executing them.

    It is highly recommended to enable fake mode when exporting models that
    are too large to fit into memory.

    Returns:
        A :class:`ONNXFakeContext` object that must be passed to :func:`dynamo_export`
        through the :attr:`ExportOptions.fake_context` argument.

    Example::

        # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
        >>> import torch
        >>> import torch.onnx
        >>> class MyModel(torch.nn.Module):  # Dummy model
        ...     def __init__(self) -> None:
        ...         super().__init__()
        ...         self.linear = torch.nn.Linear(2, 2)
        ...     def forward(self, x):
        ...         out = self.linear(x)
        ...         return out
        >>> with torch.onnx.enable_fake_mode() as fake_context:
        ...     my_nn_module = MyModel()
        ...     arg1 = torch.randn(2, 2, 2)  # positional input 1
        >>> export_options = torch.onnx.ExportOptions(fake_context=fake_context)
        >>> export_output = torch.onnx.dynamo_export(
        ...     my_nn_module,
        ...     arg1,
        ...     export_options=export_options
        ... )
        >>> # Saving model WITHOUT initializers
        >>> export_output.save("my_model_without_initializers.onnx")
        >>> # Saving model WITH initializers
        >>> export_output.save("my_model_with_initializers.onnx", model_state_dict=MyModel().state_dict())

    .. warning::
        This API is experimental and is *NOT* backward-compatible.

    """
    from torch._subclasses import fake_tensor
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    # This overrides the internal `FakeTensorMode` instance created by `torch._dynamo.export`[1].
    # It is a good idea to keep them in sync (constructor args) to maintain the same default behavior.
    # [1] `torch/_dynamo/output_graph.py::InstructionTranslator::OutputGraph.__init__`
    # Mixed fake/real tensors are only allowed when `torch.onnx.dynamo_export` is not called within `FakeTensorMode`.
    # This is needed because models can create new parameters during the `forward(self, *args, **kwargs)` run.
    fake_mode = fake_tensor.FakeTensorMode(
        allow_non_fake_inputs=not torch._guards.detect_fake_mode(),
        shape_env=ShapeEnv(
            allow_scalar_outputs=False, allow_dynamic_output_shape_ops=False
        ),
    )
    # The patcher is needed when the user calls `fake_model.load_state_dict(...)` within fake mode.
    patcher_context = patcher.ONNXTorchPatcher()
    fake_context = ONNXFakeContext(fake_mode=fake_mode)
    with fake_mode, patcher_context:
        yield fake_context
    fake_context.state_dict_paths = tuple(
        patcher_context.paths,
    )  # type: ignore[assignment]


@runtime_checkable
class ExportOutputSerializer(Protocol):
    """Protocol for serializing an ONNX graph into a specific format (e.g. Protobuf).
    Note that this is an advanced usage scenario."""

    def serialize(
        self, export_output: ExportOutput, destination: io.BufferedIOBase
    ) -> None:
        """Protocol method that must be implemented for serialization.

        Args:
            export_output: Represents the in-memory exported ONNX model
            destination: A binary IO stream or pre-allocated buffer into which
                the serialized model should be written.

        Example:

            A simple serializer that writes the exported :py:obj:`onnx.ModelProto` in Protobuf
            format to ``destination``:

            ::

                # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
                >>> import io
                >>> import torch
                >>> import torch.onnx
                >>> class MyModel(torch.nn.Module):  # Dummy model
                ...     def __init__(self) -> None:
                ...         super().__init__()
                ...         self.linear = torch.nn.Linear(2, 2)
                ...     def forward(self, x):
                ...         out = self.linear(x)
                ...         return out
                >>> class ProtobufExportOutputSerializer:
                ...     def serialize(
                ...         self, export_output: torch.onnx.ExportOutput, destination: io.BufferedIOBase
                ...     ) -> None:
                ...         destination.write(export_output.model_proto.SerializeToString())
                >>> model = MyModel()
                >>> arg1 = torch.randn(2, 2, 2)  # positional input 1
                >>> torch.onnx.dynamo_export(model, arg1).save(
                ...     destination="exported_model.onnx",
                ...     serializer=ProtobufExportOutputSerializer(),
                ... )
        """
        ...


class ProtobufExportOutputSerializer:
    """Serializes ONNX graph as Protobuf."""

    @_beartype.beartype
    def serialize(
        self, export_output: ExportOutput, destination: io.BufferedIOBase
    ) -> None:
        import onnx

        if not isinstance(export_output.model_proto, onnx.ModelProto):  # type: ignore[attr-defined]
            raise ValueError("export_output.model_proto is not an onnx.ModelProto")
        destination.write(export_output.model_proto.SerializeToString())


class LargeProtobufExportOutputSerializer:
    """Serializes ONNX graph as Protobuf.

    Falls back to serializing as Protobuf with external data for models larger than 2GB.
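
    Example:

        A sketch of selecting this serializer explicitly; note that
        :meth:`ExportOutput.save` already uses it by default whenever ``destination``
        is a path string, so this is rarely needed::

            # xdoctest: +SKIP("illustrative only; `export_output` produced elsewhere")
            >>> serializer = LargeProtobufExportOutputSerializer("exported_model.onnx")
            >>> with open("exported_model.onnx", "wb") as f:  # stream is ignored by this serializer
            ...     export_output.save(f, serializer=serializer)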
    """

    _destination_path: Final[str]

    def __init__(self, destination_path: str):
        self._destination_path = destination_path

    @_beartype.beartype
    def serialize(
        self, export_output: ExportOutput, destination: io.BufferedIOBase
    ) -> None:
        """`destination` is ignored. The model is saved to `self._destination_path` instead."""
        import onnx

        try:
            onnx.save_model(export_output.model_proto, self._destination_path)  # type: ignore[attr-defined]
        except ValueError:
            # ValueError: Message onnx.ModelProto exceeds maximum protobuf size of 2GB
            # Fall back to serializing the model with external data.
            onnx.save_model(  # type: ignore[attr-defined]
                export_output.model_proto,
                self._destination_path,
                save_as_external_data=True,
            )


class ExportOutput:
    """An in-memory representation of a PyTorch model that has been exported to ONNX."""

    _model_proto: Final[onnx.ModelProto]  # type: ignore[name-defined]
    _input_adapter: Final[io_adapter.InputAdapter]
    _output_adapter: Final[io_adapter.OutputAdapter]
    _diagnostic_context: Final[diagnostics.DiagnosticContext]
    _fake_context: Final[Optional[ONNXFakeContext]]
    _export_exception: Final[Optional[Exception]]

    @_beartype.beartype
    def __init__(
        self,
        model_proto: onnx.ModelProto,  # type: ignore[name-defined]
        input_adapter: io_adapter.InputAdapter,
        output_adapter: io_adapter.OutputAdapter,
        diagnostic_context: diagnostics.DiagnosticContext,
        *,
        fake_context: Optional[ONNXFakeContext] = None,
        export_exception: Optional[Exception] = None,
    ):
        self._model_proto = model_proto
        self._input_adapter = input_adapter
        self._output_adapter = output_adapter
        self._diagnostic_context = diagnostic_context
        self._fake_context = fake_context
        self._export_exception = export_exception

    @property
    def model_proto(self) -> onnx.ModelProto:  # type: ignore[name-defined]
        """The exported ONNX model as an :py:obj:`onnx.ModelProto`."""

        if self._export_exception is not None:
            raise self._export_exception
        return self._model_proto

    @property
    def diagnostic_context(self) -> diagnostics.DiagnosticContext:
        """The diagnostic context associated with the export."""

        return self._diagnostic_context

    @property
    def fake_context(self) -> Optional[ONNXFakeContext]:
        """The fake context associated with the export."""

        return self._fake_context

    @_beartype.beartype
    def adapt_torch_inputs_to_onnx(
        self, *model_args, **model_kwargs
    ) -> Sequence[Union[torch.Tensor, int, float, bool]]:
        """Converts the PyTorch model inputs to exported ONNX model inputs format.

        Due to design differences, input/output formats between a PyTorch model and its exported
        ONNX model are often not the same. E.g., ``None`` is allowed as a PyTorch model input, but is
        not supported by ONNX. Nested constructs of tensors are allowed as PyTorch model inputs,
        but only flattened tensors are supported by ONNX, etc.

        The actual adapting steps are associated with each individual export. It
        depends on the PyTorch model, the particular set of model_args and model_kwargs
        used for the export, and export options.

        This method replays the adapting steps recorded during export.

        Args:
            model_args: The PyTorch model inputs.
            model_kwargs: The PyTorch model keyword inputs.

        Returns:
            A sequence of tensors converted from PyTorch model inputs.

        Example::

            # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
            >>> import torch
            >>> import torch.onnx
            >>> from typing import Dict, Tuple
            >>> def func_with_nested_input_structure(
            ...     x_dict: Dict[str, torch.Tensor],
            ...     y_tuple: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
            ... ):
            ...     if "a" in x_dict:
            ...         x = x_dict["a"]
            ...     elif "b" in x_dict:
            ...         x = x_dict["b"]
            ...     else:
            ...         x = torch.randn(3)
            ...
            ...     y1, (y2, y3) = y_tuple
            ...
            ...     return x + y1 + y2 + y3
            >>> x_dict = {"a": torch.tensor(1.)}
            >>> y_tuple = (torch.tensor(2.), (torch.tensor(3.), torch.tensor(4.)))
            >>> export_output = torch.onnx.dynamo_export(func_with_nested_input_structure, x_dict, y_tuple)
            >>> print(x_dict, y_tuple)
            {'a': tensor(1.)} (tensor(2.), (tensor(3.), tensor(4.)))
            >>> print(export_output.adapt_torch_inputs_to_onnx(x_dict, y_tuple))
            (tensor(1.), tensor(2.), tensor(3.), tensor(4.))

        .. warning::
            This API is experimental and is *NOT* backward-compatible.

        """
        return self._input_adapter.apply(*model_args, **model_kwargs)

    @_beartype.beartype
    def adapt_torch_outputs_to_onnx(
        self, model_outputs: Any
    ) -> Sequence[Union[torch.Tensor, int, float, bool]]:
        """Converts the PyTorch model outputs to exported ONNX model outputs format.

        Due to design differences, input/output formats between a PyTorch model and its exported
        ONNX model are often not the same. E.g., ``None`` is allowed as a PyTorch model output, but is
        not supported by ONNX. Nested constructs of tensors are allowed as PyTorch model outputs,
        but only flattened tensors are supported by ONNX, etc.

        The actual adapting steps are associated with each individual export. It
        depends on the PyTorch model, the particular set of model_args and model_kwargs
        used for the export, and export options.

        This method replays the adapting steps recorded during export.

        Args:
            model_outputs: The PyTorch model outputs.

        Returns:
            PyTorch model outputs in exported ONNX model outputs format.

        Example::

            # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
            >>> import torch
            >>> import torch.onnx
            >>> def func_returning_tuples(x, y, z):
            ...     x = x + y
            ...     y = y + z
            ...     z = x + y
            ...     return (x, (y, z))
            >>> x = torch.tensor(1.)
            >>> y = torch.tensor(2.)
            >>> z = torch.tensor(3.)
            >>> export_output = torch.onnx.dynamo_export(func_returning_tuples, x, y, z)
            >>> pt_output = func_returning_tuples(x, y, z)
            >>> print(pt_output)
            (tensor(3.), (tensor(5.), tensor(8.)))
            >>> print(export_output.adapt_torch_outputs_to_onnx(pt_output))
            [tensor(3.), tensor(5.), tensor(8.)]

        .. warning::
            This API is experimental and is *NOT* backward-compatible.

        """
        return self._output_adapter.apply(model_outputs)

    @_beartype.beartype
    def save(
        self,
        destination: Union[str, io.BufferedIOBase],
        *,
        model_state_dict: Optional[Union[Dict[str, Any], str]] = None,
        serializer: Optional[ExportOutputSerializer] = None,
    ) -> None:
        """Saves the in-memory ONNX model to ``destination`` using the specified ``serializer``.

        Args:
            destination: The destination to save the ONNX model. It can be either a string or a file-like object.
                When used with ``model_state_dict``, it must be a string with the full path to the destination.
                In that case, besides saving the ONNX model, a folder with the "_initializers" suffix (without extension)
                will be created to store each initializer of the ONNX model in a separate file. For example, if the
                destination is "/path/model.onnx", the initializers will be saved in the "/path/model_initializers/" folder.
            model_state_dict: The state_dict of the PyTorch model containing all its weights.
                It can be either a dict, as returned by :meth:`model.state_dict`, or a string
                with the path to a checkpoint file.
                Required when :func:`enable_fake_mode` is used but real initializers are needed on the ONNX graph.
            serializer: The serializer to use. If not specified, the model will be serialized as Protobuf.
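
        Example:

            A sketch of saving a fake-mode export with real weights loaded from a
            checkpoint; the paths are illustrative::

                # xdoctest: +SKIP("illustrative only; `export_output` produced elsewhere")
                >>> export_output.save(
                ...     "/path/model.onnx", model_state_dict="/path/checkpoint.pt"
                ... )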
        """

        if serializer is None:
            if isinstance(destination, str):
                serializer = LargeProtobufExportOutputSerializer(destination)
            else:
                serializer = ProtobufExportOutputSerializer()

        # Add initializers when symbolic tracing is enabled
        _model_state_dict_files: List[Union[str, io.BytesIO]] = []
        if model_state_dict is not None:
            if isinstance(model_state_dict, dict):
                model_state_dict_file = io.BytesIO()
                torch.save(model_state_dict, model_state_dict_file)
                model_state_dict_file.seek(0)
                _model_state_dict_files.append(model_state_dict_file)
            else:
                assert isinstance(
                    model_state_dict, str
                ), "model_state_dict must be a path to the model's state_dict or the actual state_dict"
                _model_state_dict_files.append(model_state_dict)

        # Load state from a previous model.load_state_dict() call within the enable_fake_mode() context
        if self._fake_context and self._fake_context.state_dict_paths:
            for path in self._fake_context.state_dict_paths:
                if path in _model_state_dict_files:
                    # ignore duplicate
                    continue
                try:
                    extra_state_dict = torch.load(path)
                    extra_state_dict_file = io.BytesIO()
                    torch.save(extra_state_dict, extra_state_dict_file)
                    extra_state_dict_file.seek(0)
                    _model_state_dict_files.append(extra_state_dict_file)
                except FileNotFoundError:
                    # It is ok to ignore transient state_dict files created within the context manager
                    pass

        if _model_state_dict_files:
            if not isinstance(destination, str):
                raise RuntimeError(
                    "`destination` must be a string with a path when `model_state_dict` is specified."
                )
            destination_path, destination_filename = os.path.split(destination)
            onnx_model_location = destination_filename
            onnx_initializer_location = (
                destination_filename.split(".")[0] + "_initializers"
            )
            # TODO: Should this be part of the serializer?
            fx_serialization.save_model_with_external_data(
                destination_path,
                onnx_model_location,
                onnx_initializer_location,
                tuple(_model_state_dict_files),
                self.model_proto,
            )
        else:
            if isinstance(destination, str):
                with open(destination, "wb") as f:
                    serializer.serialize(self, f)
            else:
                try:
                    serializer.serialize(self, destination)
                except ValueError:
                    raise ValueError(
                        "'destination' should be provided as a path-like string when saving a model larger than 2GB. "
                        "External tensor data will be saved alongside the model on disk."
                    )

    @_beartype.beartype
    def save_diagnostics(self, destination: str) -> None:
        """Saves the export diagnostics as a SARIF log to the specified destination path.

        Args:
            destination: The destination to save the diagnostics SARIF log.
                It must have a `.sarif` extension.

        Raises:
            ValueError: If the destination path does not end with the `.sarif` extension.
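
        Example:

            A minimal sketch (``export_output`` produced elsewhere)::

                # xdoctest: +SKIP("illustrative only")
                >>> export_output.save_diagnostics("export_report.sarif")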
        """
        if not destination.endswith(".sarif"):
            message = f"'destination' must have a .sarif extension, got {destination}"
            log.fatal(message)
            raise ValueError(message)

        self.diagnostic_context.dump(destination)

    @classmethod
    def _from_failure(
        cls,
        export_exception: Exception,
        diagnostic_context: diagnostics.DiagnosticContext,
    ) -> Self:
        """
        Creates an instance of :class:`ExportOutput` when the export process encounters a failure.

        In case of a failed export, this method is used to encapsulate the exception
        and associated diagnostic context within an :class:`ExportOutput` instance for
        easier handling and debugging.

        Args:
            export_exception: The exception raised during the export process.
            diagnostic_context: The context associated with diagnostics during export.

        Returns:
            An instance of :class:`ExportOutput` representing the failed export output.
        """
        # Defer `import onnx` out of `import torch` path
        # https://github.com/pytorch/pytorch/issues/103764
        import onnx

        return ExportOutput(
            onnx.ModelProto(),  # type: ignore[attr-defined]
            io_adapter.InputAdapter(),
            io_adapter.OutputAdapter(),
            diagnostic_context,
            export_exception=export_exception,
        )


class FXGraphExtractor(abc.ABC):
    """Abstract interface for FX graph extractor engines.
    This class isolates FX extraction logic from the rest of the export logic.
    That allows a single ONNX exporter to leverage different FX graphs."""

    def __init__(self) -> None:
        super().__init__()
        self.input_adapter: io_adapter.InputAdapter = io_adapter.InputAdapter()
        self.output_adapter: io_adapter.OutputAdapter = io_adapter.OutputAdapter()

    @abc.abstractmethod
    def generate_fx(
        self,
        options: ResolvedExportOptions,
        model: Union[torch.nn.Module, Callable],
        model_args: Sequence[Any],
        model_kwargs: Mapping[str, Any],
    ) -> torch.fx.GraphModule:
        """Analyzes user ``model`` and generates an FX graph.
        Args:
            options: The export options.
            model: The user model.
            model_args: The model's positional input arguments.
            model_kwargs: The model's keyword input arguments.
        Returns:
            The generated FX Graph.
        """
        ...


class Exporter:
    @_beartype.beartype
    def __init__(
        self,
        options: Union[ExportOptions, ResolvedExportOptions],
        model: Union[torch.nn.Module, Callable],
        model_args: Sequence[Any],
        model_kwargs: Mapping[str, Any],
    ):
        self.options = ResolvedExportOptions(options)
        assert self.options is not None

        self.model = model
        self.model_args = model_args
        self.model_kwargs = model_kwargs

        # TODO: Retire FXSymbolicTracer
        # NOTE: FXSymbolicTracer would fail in this assert, as it does not use `enable_fake_mode`
        from torch.onnx._internal.fx import fx_symbolic_graph_extractor

        if not isinstance(
            self.options.fx_tracer, fx_symbolic_graph_extractor.FXSymbolicTracer
        ):
            self._assert_fake_tensor_mode()

    def export(self) -> ExportOutput:
        with self.options.diagnostic_context:
            graph_module = self.options.fx_tracer.generate_fx(
                self.options, self.model, self.model_args, self.model_kwargs
            )

            updated_model_args = self.options.fx_tracer.input_adapter.apply(
                *self.model_args, **self.model_kwargs
            )

            # TODO: Design the passes API
            graph_module = pre_export_passes(
                self.options, self.model, graph_module, updated_model_args
            )

            # TODO: Defer `import onnxscript` out of `import torch` path
            # https://github.com/pytorch/pytorch/issues/103764
            from torch.onnx._internal.fx import fx_onnx_interpreter

            fx_interpreter = fx_onnx_interpreter.FxOnnxInterpreter(
                diagnostic_context=self.options.diagnostic_context
            )
            onnxscript_graph = fx_interpreter.run(
                fx_graph_module=graph_module,
                onnxfunction_dispatcher=self.options.onnxfunction_dispatcher,
                op_level_debug=self.options.op_level_debug,
            )

            # NOTE: Filter out initializers that are fake tensors when exporting in fake mode.
            # Otherwise, the ONNX exporter will fail: RuntimeError: basic_string::_M_construct null
            # not valid.
            # Concrete data is expected to be filled for those initializers later during `ExportOutput.save`.
            if self.options.fake_context is not None:
                initializers_with_real_tensors: Dict[str, torch.Tensor] = {}
                for (
                    initializer_name,
                    initializer,
                ) in onnxscript_graph.initializers.items():
                    if not isinstance(initializer, torch._subclasses.FakeTensor):
                        initializers_with_real_tensors[initializer_name] = initializer
                onnxscript_graph.initializers = initializers_with_real_tensors

            # Export the ONNX Script graph to an ONNX ModelProto.
            onnx_model = onnxscript_graph.to_model_proto(
                self.options.onnx_registry.opset_version,
            )

            return torch.onnx.ExportOutput(
                onnx_model,
                self.options.fx_tracer.input_adapter,
                self.options.fx_tracer.output_adapter,
                self.options.diagnostic_context,
                fake_context=self.options.fake_context,
            )

    def _assert_fake_tensor_mode(self):
        """Asserts that the model and its inputs are consistent with the export's fake mode setting."""

        # Case 1: Model with fake inputs/weights and without enabling fake mode
        has_any_fake_tensor = pytree.tree_any(
            lambda x: isinstance(x, torch._subclasses.FakeTensor),
            (self.model_args, self.model_kwargs),
        )
        has_any_fake_param_or_buffer = False
        if isinstance(self.model, torch.nn.Module):
            has_any_fake_param_or_buffer = pytree.tree_any(
                lambda x: isinstance(x, torch._subclasses.FakeTensor),
                (self.model.parameters(), self.model.buffers()),
            )
        if (
            has_any_fake_tensor or has_any_fake_param_or_buffer
        ) and not self.options.fake_context:
            raise RuntimeError(
                "Cannot export a model with fake inputs/weights without enabling fake mode.",
            )
        # Case 2: Model with non-fake inputs/weights and enabled fake mode
        has_any_non_fake_tensors = pytree.tree_any(
            lambda x: isinstance(x, torch.Tensor)
            and not isinstance(x, torch._subclasses.FakeTensor),
            (self.model_args, self.model_kwargs),
        )
        has_any_non_fake_param_or_buffer = False
        if isinstance(self.model, torch.nn.Module):
            has_any_non_fake_param_or_buffer = pytree.tree_any(
                lambda x: isinstance(x, torch.Tensor)
                and not isinstance(x, torch._subclasses.FakeTensor),
                (self.model.parameters(), self.model.buffers()),
            )
        if (
            has_any_non_fake_tensors or has_any_non_fake_param_or_buffer
        ) and self.options.fake_context:
            raise RuntimeError(
                "Cannot export a model with non-fake inputs/weights while fake mode is enabled.",
            )


class UnsatisfiedDependencyError(RuntimeError):
    """Raised when an ONNX exporter dependency cannot be satisfied."""

    def __init__(self, package_name: str, message: str):
        super().__init__(message)
        self.package_name = package_name


class OnnxExporterError(RuntimeError):
    """Raised when an ONNX exporter error occurs.

    This exception is thrown when there's an error during the ONNX export process.
    It encapsulates the :class:`ExportOutput` object generated until the failure, allowing
    access to the partial export results and associated metadata.
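
    Example:

        A sketch of recovering diagnostics from a failed export (``model`` and ``x``
        defined elsewhere)::

            # xdoctest: +SKIP("illustrative only")
            >>> try:
            ...     torch.onnx.dynamo_export(model, x)
            ... except torch.onnx.OnnxExporterError as e:
            ...     e.export_output.save_diagnostics("failed_export.sarif")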
    """

    export_output: Final[ExportOutput]

    def __init__(self, export_output: ExportOutput, message: str):
        """
        Initializes the OnnxExporterError with the given export output and message.

        Args:
            export_output (ExportOutput): The partial results of the ONNX export.
            message (str): The error message to be displayed.
        """
        super().__init__(message)
        self.export_output = export_output


@_beartype.beartype
def _assert_dependencies(export_options: ResolvedExportOptions):
    opset_version = export_options.onnx_registry.opset_version

    def missing_package(package_name: str, exc_info: logging._ExcInfoType):
        message = (
            f"Please install the `{package_name}` package "
            f"(e.g. `python -m pip install {package_name}`)."
        )
        log.fatal(message, exc_info=exc_info)
        return UnsatisfiedDependencyError(package_name, message)

    def missing_opset(package_name: str):
        message = (
            f"The installed `{package_name}` does not support the specified ONNX opset "
            f"version {opset_version}. Install a newer `{package_name}` package or "
            f"specify an older opset version."
        )
        log.fatal(message)
        return UnsatisfiedDependencyError(package_name, message)

    try:
        import onnx
    except ImportError as e:
        raise missing_package("onnx", e) from e

    if onnx.defs.onnx_opset_version() < opset_version:
        raise missing_opset("onnx")

    try:
        # PyTorch runs lintrunner in CI without onnxscript installed
        import onnxscript  # type: ignore[import]
    except ImportError as e:
        raise missing_package("onnxscript", e) from e

    if not isinstance(
        onnxscript.onnx_opset.all_opsets[("", opset_version)],
        onnxscript.values.Opset,
    ):
        raise missing_opset("onnxscript")


@_beartype.beartype
def dynamo_export(
    model: Union[torch.nn.Module, Callable],
    /,
    *model_args,
    export_options: Optional[ExportOptions] = None,
    **model_kwargs,
) -> ExportOutput:
    """Export a torch.nn.Module to an ONNX graph.

    Args:
        model: The PyTorch model to be exported to ONNX.
        model_args: Positional inputs to ``model``.
        model_kwargs: Keyword inputs to ``model``.
        export_options: Options to influence the export to ONNX.

    Returns:
        An in-memory representation of the exported ONNX model.

    **Example 1 - Simplest export**
    ::

        class MyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)
            def forward(self, x, bias=None):
                out = self.linear(x)
                out = out + bias
                return out
        model = MyModel()
        kwargs = {"bias": 3.}
        args = (torch.randn(2, 2, 2),)
        torch.onnx.dynamo_export(
            model,
            *args,
            **kwargs).save("my_simple_model.onnx")

    **Example 2 - Exporting with dynamic shapes**
    ::

        # The previous model can be exported with dynamic shapes
        export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
        export_output = torch.onnx.dynamo_export(
            model,
            *args,
            **kwargs,
            export_options=export_options)
        export_output.save("my_dynamic_model.onnx")


    By printing the input's dynamic dimensions, we can see that the input shape is no longer (2, 2, 2)
    ::

        >>> print(export_output.model_proto.graph.input[0])
        name: "arg0"
        type {
          tensor_type {
            elem_type: 1
            shape {
              dim {
                dim_param: "arg0_dim_0"
              }
              dim {
                dim_param: "arg0_dim_1"
              }
              dim {
                dim_param: "arg0_dim_2"
              }
            }
          }
        }
    """

    resolved_export_options = (
        export_options
        if isinstance(export_options, ResolvedExportOptions)
        else ResolvedExportOptions(export_options)
    )

    _assert_dependencies(resolved_export_options)

    try:
        return Exporter(
            options=resolved_export_options,
            model=model,
            model_args=model_args,
            model_kwargs=model_kwargs,
        ).export()
    except Exception as e:
        sarif_report_path = _DEFAULT_FAILED_EXPORT_SARIF_LOG_PATH
        resolved_export_options.diagnostic_context.dump(sarif_report_path)
        message = (
            f"Failed to export the model to ONNX. Generating SARIF report at '{sarif_report_path}'. "
            "SARIF is a standard format for the output of static analysis tools. "
            "SARIF logs can be loaded in the VS Code SARIF viewer extension, "
            "or the SARIF web viewer (https://microsoft.github.io/sarif-web-component/). "
            f"Please report a bug on PyTorch Github: {_PYTORCH_GITHUB_ISSUES_URL}"
        )
        raise OnnxExporterError(
            ExportOutput._from_failure(e, resolved_export_options.diagnostic_context),
            message,
        ) from e


@_beartype.beartype
def pre_export_passes(
    options: ResolvedExportOptions,
    original_model: Union[torch.nn.Module, Callable],
    fx_module: torch.fx.GraphModule,
    fx_module_args: Sequence[Any],
):
    # TODO: Import here to prevent circular dependency
    from torch.onnx._internal.fx import analysis, passes

    diagnostic_context = options.diagnostic_context

    # Apply the decomposition table to the input graph.
    module = passes.Decompose(
        diagnostic_context,
        fx_module,
        options.decomposition_table,
        enable_dynamic_axes=options.dynamic_shapes,
        allow_fake_constant=options.fake_context is not None,
    ).run(*fx_module_args)

    # ONNX does not support views and mutations.
    # Functionalize to get a semantically equivalent graph without mutations.
    module = passes.Functionalize(
        diagnostic_context,
        module,
        enable_dynamic_axes=options.dynamic_shapes,
        allow_fake_constant=options.fake_context is not None,
    ).run(*fx_module_args)

    # Input mutations are detected and distilled after the `Functionalize` pass.
    # Remove them since ONNX inference does not need them.
    module = passes.RemoveInputMutation(diagnostic_context, module).run(*fx_module_args)

    # ONNX does not support the concept of (implicit) type promotion.
    # Insert type casts explicitly where needed.
    module = passes.InsertTypePromotion(diagnostic_context, module).run()

    analysis.UnsupportedFxNodesAnalysis(
        diagnostic_context, module, options.onnxfunction_dispatcher
    ).analyze(infra.levels.ERROR)

    if isinstance(original_model, torch.nn.Module):
        module = passes.RestoreParameterAndBufferNames(
            diagnostic_context, module, original_model
        ).run()

    # This operation should be invoked as the last pre-export pass.
    # See [NOTE: Modularize pass ordering]
    module = passes.Modularize(diagnostic_context, module).run()

    # ONNX does not support None inputs. During graph building, all None inputs
    # are removed. Here we register this step to the input adapter.
    options.fx_tracer.input_adapter.append_step(io_adapter.RemoveNoneInputStep())

    # NOTE: temp workaround for https://github.com/pytorch/pytorch/issues/99534
    # Dynamo doesn't support non-tensor inputs.
    options.fx_tracer.input_adapter.append_step(io_adapter.RemoveNonTensorInputStep())

    # ONNX does not support complex inputs. During graph building, all complex inputs
    # are converted to real representation inputs. Here we register this step to the
    # input/output adapter.
    options.fx_tracer.input_adapter.append_step(
        io_adapter.ConvertComplexToRealRepresentationInputStep()
    )

    # ONNX can't represent collection types (e.g., dictionary, tuple of tuple of
    # tensor, etc), so we flatten the collection and register each element as an output.
    options.fx_tracer.output_adapter.append_step(io_adapter.FlattenOutputStep())

    # Output post-processing steps should happen after `FlattenOutputStep`.
    options.fx_tracer.output_adapter.append_step(
        io_adapter.ConvertComplexToRealRepresentationOutputStep()
    )

    return module


__all__ = [
    "ExportOptions",
    "ExportOutput",
    "ExportOutputSerializer",
    "UnsatisfiedDependencyError",
    "dynamo_export",
    "OnnxExporterError",
    "enable_fake_mode",
    "OnnxRegistry",
    "DiagnosticOptions",
]