pytorch/torch/onnx/__init__.py
Aaron Bockover 558e5a240e Introduce torch.onnx.dynamo_export API (#97920)
This is the first phase of the new ONNX exporter API for exporting from TorchDynamo and FX, and represents the beginning of a new era for exporting ONNX from PyTorch.

The API here is a starting point upon which we will layer more capability and expressiveness in subsequent phases. This first phase introduces the following into `torch.onnx`:

```python
dynamo_export(
    model: torch.nn.Module,
    /,
    *model_args,
    export_options: Optional[ExportOptions] = None,
    **model_kwargs,
) -> ExportOutput:
    ...

class ExportOptions:
    opset_version: Optional[int] = None
    dynamic_shapes: Optional[bool] = None
    logger: Optional[logging.Logger] = None

class ExportOutputSerializer(Protocol):
    def serialize(
        self,
        export_output: ExportOutput,
        destination: io.BufferedIOBase,
    ) -> None:
        ...

class ExportOutput:
    model_proto: onnx.ModelProto

    def save(
        self,
        destination: Union[str, io.BufferedIOBase],
        *,
        serializer: Optional[ExportOutputSerializer] = None,
    ) -> None:
        ...
```

In addition to the API introduced in the first commit of this PR, we have a few experimental exporters for Dynamo and FX to ONNX; this PR rationalizes them through the new exporter API and adjusts their tests to use the new API.

- A base `FXGraphModuleExporter` class from which all of the following exporters derive:
  - `DynamoExportExporter`: uses dynamo.export to acquire FX graph
  - `DynamoOptimizeExporter`: uses dynamo.optimize to acquire FX graph
  - `FXSymbolicTraceExporter`: uses FX symbolic tracing

The `dynamo_export` API currently uses `DynamoOptimizeExporter`.

### Next Steps (subsequent PRs):

* Combine `DynamoExportExporter` and `DynamoOptimizeExporter` into a single `DynamoExporter`.
* Make it easy to test `FXSymbolicTraceExporter` through the same API; eventually `FXSymbolicTraceExporter` goes away entirely when the Dynamo approach works for large models. We want to keep `FXSymbolicTraceExporter` around for now for experimenting and internal use.
* Parameterize (on `ExportOptions`) and consolidate Dynamo exporter tests.
  - This PR intentionally leaves the existing tests unchanged as much as possible except for the necessary plumbing.
* Subsequent API phases:
  - Diagnostics
  - Registry, dispatcher, and Custom Ops
  - Passes
  - Dynamic shapes

Fixes #94774

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97920
Approved by: https://github.com/justinchuby, https://github.com/titaiwangms, https://github.com/thiagocrepaldi, https://github.com/shubhambhokare1
2023-04-04 18:13:29 +00:00

149 lines
3.6 KiB
Python

"""ONNX exporter."""
from torch import _C
from torch._C import _onnx as _C_onnx
from torch._C._onnx import (
_CAFFE2_ATEN_FALLBACK,
OperatorExportTypes,
TensorProtoDataType,
TrainingMode,
)
from . import ( # usort:skip. Keep the order instead of sorting lexicographically
_deprecation,
errors,
symbolic_caffe2,
symbolic_helper,
symbolic_opset7,
symbolic_opset8,
symbolic_opset9,
symbolic_opset10,
symbolic_opset11,
symbolic_opset12,
symbolic_opset13,
symbolic_opset14,
symbolic_opset15,
symbolic_opset16,
symbolic_opset17,
symbolic_opset18,
utils,
)
# TODO(After 1.13 release): Remove the deprecated SymbolicContext
from ._exporter_states import ExportTypes, SymbolicContext
from ._type_utils import JitScalarType
from .errors import CheckerError # Backwards compatibility
from .utils import (
_optimize_graph,
_run_symbolic_function,
_run_symbolic_method,
export,
export_to_pretty_string,
is_in_onnx_export,
register_custom_op_symbolic,
select_model_mode_for_export,
unregister_custom_op_symbolic,
)
from ._internal.exporter import ( # usort:skip. needs to be last to avoid circular import
ExportOptions,
ExportOutput,
ExportOutputSerializer,
dynamo_export,
)
# Names re-exported by ``from torch.onnx import *``. Built up one category at a
# time with ``__all__ += [...]`` (a form static analyzers recognize); the final
# list and its ordering match a single flat literal.
# NOTE(review): ``is_onnx_log_enabled``, ``set_log_stream`` and ``log`` are
# module attributes but are not listed here -- confirm this is intentional.
__all__ = [
    # Modules
    "symbolic_helper",
    "utils",
    "errors",
]
# All opsets
__all__ += [
    "symbolic_caffe2",
    "symbolic_opset7",
    "symbolic_opset8",
    "symbolic_opset9",
    "symbolic_opset10",
    "symbolic_opset11",
    "symbolic_opset12",
    "symbolic_opset13",
    "symbolic_opset14",
    "symbolic_opset15",
    "symbolic_opset16",
    "symbolic_opset17",
    "symbolic_opset18",
]
# Enums
__all__ += [
    "ExportTypes",
    "OperatorExportTypes",
    "TrainingMode",
    "TensorProtoDataType",
    "JitScalarType",
]
# Public functions
__all__ += [
    "export",
    "export_to_pretty_string",
    "is_in_onnx_export",
    "select_model_mode_for_export",
    "register_custom_op_symbolic",
    "unregister_custom_op_symbolic",
    "disable_log",
    "enable_log",
]
# Errors
__all__ += [
    "CheckerError",  # Backwards compatibility
]
# Dynamo Exporter
__all__ += [
    "ExportOptions",
    "ExportOutput",
    "ExportOutputSerializer",
    "dynamo_export",
]
# These objects are defined in private submodules but re-exported here; point
# their ``__module__`` at the public ``torch.onnx`` namespace instead.
for _exposed_obj in (
    ExportTypes,
    JitScalarType,
    ExportOptions,
    ExportOutput,
    ExportOutputSerializer,
    dynamo_export,
):
    _exposed_obj.__module__ = "torch.onnx"
# Drop the loop variable so it does not linger as a module attribute.
del _exposed_obj
# Producer identification exposed by ``torch.onnx``; the version string comes
# straight from the ``torch._C._onnx`` extension module.
# NOTE(review): these names mirror the ONNX ModelProto producer fields --
# confirm where the exporter consumes them.
producer_name, producer_version = "pytorch", _C_onnx.PRODUCER_VERSION
@_deprecation.deprecated(
    since="1.12.0",
    removed_in="2.0",
    instructions="use `torch.onnx.export` instead",
)
def _export(*args, **kwargs):
    r"""Forward all positional and keyword arguments unchanged to ``utils._export``.

    Deprecated private entry point kept for backward compatibility; use
    ``torch.onnx.export`` instead.
    """
    # NOTE(review): the advertised removal version ("2.0") has already passed;
    # consider scheduling the actual removal of this alias.
    exported = utils._export(*args, **kwargs)
    return exported
# TODO(justinchuby): Deprecate these logging functions in favor of the new diagnostic module.

# ``is_onnx_log_enabled() -> bool``: True iff ONNX logging is currently turned on.
# This is a direct alias of the ``torch._C`` binding, not a Python wrapper.
is_onnx_log_enabled = _C._jit_is_onnx_log_enabled


def enable_log() -> None:
    r"""Turn ONNX logging on.

    The current state can be queried with ``is_onnx_log_enabled()``.
    """
    enabled = True
    _C._jit_set_onnx_log_enabled(enabled)


def disable_log() -> None:
    r"""Turn ONNX logging off.

    The current state can be queried with ``is_onnx_log_enabled()``.
    """
    enabled = False
    _C._jit_set_onnx_log_enabled(enabled)
# Documented with ``#:`` comments: a string literal placed *before* an
# assignment is a discarded expression statement, not a docstring, so its text
# never reached ``help()`` or autodoc.
#: Sets output stream for ONNX logging.
#:
#: Args:
#:     stream_name (str, default "stdout"): Only 'stdout' and 'stderr' are supported
#:         as ``stream_name``.
set_log_stream = _C._jit_set_onnx_log_output_stream
# Documented with ``#:`` comments: a string literal placed *before* an
# assignment is a discarded expression statement, not a docstring, so its text
# never reached ``help()`` or autodoc.
#: A simple logging facility for ONNX exporter.
#:
#: Args:
#:     args: Arguments are converted to string, concatenated together with a newline
#:         character appended to the end, and flushed to output stream.
log = _C._jit_onnx_log