mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-08 07:39:33 +01:00
This is the first phase of the new ONNX exporter API for exporting from TorchDynamo and FX, and represents the beginning of a new era for exporting ONNX from PyTorch.
The API here is a starting point upon which we will layer more capability and expressiveness in subsequent phases. This first phase introduces the following into `torch.onnx`:
```python
dynamo_export(
    model: torch.nn.Module,
    /,
    *model_args,
    export_options: Optional[ExportOptions] = None,
    **model_kwargs,
) -> ExportOutput:
    ...

class ExportOptions:
    opset_version: Optional[int] = None
    dynamic_shapes: Optional[bool] = None
    logger: Optional[logging.Logger] = None

class ExportOutputSerializer(Protocol):
    def serialize(
        self,
        export_output: ExportOutput,
        destination: io.BufferedIOBase,
    ) -> None:
        ...

class ExportOutput:
    model_proto: onnx.ModelProto

    def save(
        self,
        destination: Union[str, io.BufferedIOBase],
        *,
        serializer: Optional[ExportOutputSerializer] = None,
    ) -> None:
        ...
```
Beyond the API introduced in the first commit of this PR, there are several existing experiments for exporting Dynamo and FX to ONNX. This PR rationalizes them through the new Exporter API and adjusts the tests to use that API:
- A base `FXGraphModuleExporter` exporter from which all others derive:
  - `DynamoExportExporter`: uses `dynamo.export` to acquire the FX graph
  - `DynamoOptimizeExporter`: uses `dynamo.optimize` to acquire the FX graph
  - `FXSymbolicTraceExporter`: uses FX symbolic tracing
The `dynamo_export` API currently uses `DynamoOptimizeExporter`.
### Next Steps (subsequent PRs):
* Combine `DynamoExportExporter` and `DynamoOptimizeExporter` into a single `DynamoExporter`.
* Make it easy to test `FXSymbolicTraceExporter` through the same API; eventually `FXSymbolicTraceExporter` goes away entirely when the Dynamo approach works for large models. We want to keep `FXSymbolicTraceExporter` around for now for experimenting and internal use.
* Parameterize (on `ExportOptions`) and consolidate Dynamo exporter tests.
  - This PR intentionally leaves the existing tests unchanged as much as possible, except for the necessary plumbing.
* Subsequent API phases:
  - Diagnostics
  - Registry, dispatcher, and Custom Ops
  - Passes
  - Dynamic shapes
Fixes #94774
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97920
Approved by: https://github.com/justinchuby, https://github.com/titaiwangms, https://github.com/thiagocrepaldi, https://github.com/shubhambhokare1
149 lines
3.6 KiB
Python
149 lines
3.6 KiB
Python
"""ONNX exporter."""
|
|
|
|
from torch import _C
|
|
from torch._C import _onnx as _C_onnx
|
|
from torch._C._onnx import (
|
|
_CAFFE2_ATEN_FALLBACK,
|
|
OperatorExportTypes,
|
|
TensorProtoDataType,
|
|
TrainingMode,
|
|
)
|
|
|
|
from . import ( # usort:skip. Keep the order instead of sorting lexicographically
|
|
_deprecation,
|
|
errors,
|
|
symbolic_caffe2,
|
|
symbolic_helper,
|
|
symbolic_opset7,
|
|
symbolic_opset8,
|
|
symbolic_opset9,
|
|
symbolic_opset10,
|
|
symbolic_opset11,
|
|
symbolic_opset12,
|
|
symbolic_opset13,
|
|
symbolic_opset14,
|
|
symbolic_opset15,
|
|
symbolic_opset16,
|
|
symbolic_opset17,
|
|
symbolic_opset18,
|
|
utils,
|
|
)
|
|
|
|
# TODO(After 1.13 release): Remove the deprecated SymbolicContext
|
|
from ._exporter_states import ExportTypes, SymbolicContext
|
|
from ._type_utils import JitScalarType
|
|
from .errors import CheckerError # Backwards compatibility
|
|
from .utils import (
|
|
_optimize_graph,
|
|
_run_symbolic_function,
|
|
_run_symbolic_method,
|
|
export,
|
|
export_to_pretty_string,
|
|
is_in_onnx_export,
|
|
register_custom_op_symbolic,
|
|
select_model_mode_for_export,
|
|
unregister_custom_op_symbolic,
|
|
)
|
|
|
|
from ._internal.exporter import ( # usort:skip. needs to be last to avoid circular import
|
|
ExportOptions,
|
|
ExportOutput,
|
|
ExportOutputSerializer,
|
|
dynamo_export,
|
|
)
|
|
|
|
# Names exported by ``from torch.onnx import *``. Order and contents are kept
# exactly as before; only the stray non-Python separator lines were removed.
__all__ = [
    # Modules
    "symbolic_helper",
    "utils",
    "errors",
    # All opsets
    "symbolic_caffe2",
    "symbolic_opset7",
    "symbolic_opset8",
    "symbolic_opset9",
    "symbolic_opset10",
    "symbolic_opset11",
    "symbolic_opset12",
    "symbolic_opset13",
    "symbolic_opset14",
    "symbolic_opset15",
    "symbolic_opset16",
    "symbolic_opset17",
    "symbolic_opset18",
    # Enums
    "ExportTypes",
    "OperatorExportTypes",
    "TrainingMode",
    "TensorProtoDataType",
    "JitScalarType",
    # Public functions
    "export",
    "export_to_pretty_string",
    "is_in_onnx_export",
    "select_model_mode_for_export",
    "register_custom_op_symbolic",
    "unregister_custom_op_symbolic",
    "disable_log",
    "enable_log",
    # Errors
    "CheckerError",  # Backwards compatibility
    # Dynamo Exporter
    "ExportOptions",
    "ExportOutput",
    "ExportOutputSerializer",
    "dynamo_export",
]
|
|
|
|
# Set namespace for exposed private names.
# Each of these objects is imported from another module of this package;
# overwriting ``__module__`` makes them report "torch.onnx" as their module.
ExportTypes.__module__ = "torch.onnx"
JitScalarType.__module__ = "torch.onnx"
ExportOptions.__module__ = "torch.onnx"
ExportOutput.__module__ = "torch.onnx"
ExportOutputSerializer.__module__ = "torch.onnx"
dynamo_export.__module__ = "torch.onnx"
|
|
|
|
# Producer identification exposed as module attributes: a fixed name plus the
# version string taken from the ``torch._C._onnx`` bindings.
producer_name = "pytorch"
producer_version = _C_onnx.PRODUCER_VERSION
|
|
|
|
|
|
@_deprecation.deprecated(
    since="1.12.0", removed_in="2.0", instructions="use `torch.onnx.export` instead"
)
def _export(*args, **kwargs):
    """Deprecated entry point that forwards to ``utils._export``.

    Marked deprecated since 1.12.0 and scheduled for removal in 2.0; use
    ``torch.onnx.export`` instead.

    Args:
        *args: Positional arguments passed through unchanged.
        **kwargs: Keyword arguments passed through unchanged.

    Returns:
        Whatever ``utils._export`` returns for the given arguments.
    """
    return utils._export(*args, **kwargs)
|
|
|
|
|
|
# TODO(justinchuby): Deprecate these logging functions in favor of the new diagnostic module.
|
|
|
|
def is_onnx_log_enabled() -> bool:
    r"""Returns True iff ONNX logging is turned on."""
    # Thin wrapper (instead of a bare alias) so the function carries a Python
    # docstring and a return annotation.
    return _C._jit_is_onnx_log_enabled()
|
|
|
|
|
|
def enable_log() -> None:
    r"""Enables ONNX logging."""
    _C._jit_set_onnx_log_enabled(True)
|
|
|
|
|
|
def disable_log() -> None:
    r"""Disables ONNX logging."""
    _C._jit_set_onnx_log_enabled(False)
|
|
|
|
|
def set_log_stream(stream_name: str = "stdout") -> None:
    r"""Sets output stream for ONNX logging.

    Args:
        stream_name (str, default "stdout"): Only 'stdout' and 'stderr' are supported
            as ``stream_name``.
    """
    # A real function (rather than an alias preceded by a free-standing string)
    # so the documentation is attached to ``set_log_stream`` and the documented
    # default is actually available to callers.
    _C._jit_set_onnx_log_output_stream(stream_name)
|
|
|
|
|
def log(*args) -> None:
    r"""A simple logging facility for ONNX exporter.

    Args:
        args: Arguments are converted to string, concatenated together with a newline
            character appended to the end, and flushed to output stream.
    """
    # A real function (rather than an alias preceded by a free-standing string)
    # so the documentation is attached to ``log``; arguments are forwarded as-is.
    _C._jit_onnx_log(*args)
|