[ONNX] Remove ExportTypes (#137789)

Remove deprecated ExportTypes and the `_exporter_states` module. Only protobuf (default) is supported going forward.

Differential Revision: [D64412947](https://our.internmc.facebook.com/intern/diff/D64412947)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137789
Approved by: https://github.com/titaiwangms, https://github.com/xadupre
This commit is contained in:
Justin Chu 2024-10-11 09:53:16 -07:00 committed by PyTorch MergeBot
parent af0bc75460
commit 6e38c87ad0
4 changed files with 15 additions and 65 deletions

View File

@ -24,7 +24,6 @@ __all__ = [
"symbolic_opset19",
"symbolic_opset20",
# Enums
"ExportTypes",
"OperatorExportTypes",
"TrainingMode",
"TensorProtoDataType",
@ -57,7 +56,6 @@ from torch import _C
from torch._C import _onnx as _C_onnx
from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode
from ._exporter_states import ExportTypes
from ._internal.exporter._onnx_program import ONNXProgram
from ._internal.onnxruntime import (
is_onnxrt_backend_supported,
@ -115,7 +113,6 @@ if TYPE_CHECKING:
# Set namespace for exposed private names
DiagnosticOptions.__module__ = "torch.onnx"
ExportOptions.__module__ = "torch.onnx"
ExportTypes.__module__ = "torch.onnx"
JitScalarType.__module__ = "torch.onnx"
ONNXProgram.__module__ = "torch.onnx"
ONNXRuntimeOptions.__module__ = "torch.onnx"

View File

@ -4,19 +4,21 @@
from __future__ import annotations
import glob
import io
import os
import shutil
import zipfile
from typing import Any, Mapping
from typing import Any, Mapping, TYPE_CHECKING
import torch
import torch.jit._trace
import torch.serialization
from torch.onnx import _constants, _exporter_states, errors
from torch.onnx import errors
from torch.onnx._internal import jit_utils, registration
if TYPE_CHECKING:
import io
def export_as_test_case(
model_bytes: bytes, inputs_data, outputs_data, name: str, dir: str
) -> str:
@ -54,7 +56,6 @@ def export_as_test_case(
_export_file(
model_bytes,
os.path.join(test_case_dir, "model.onnx"),
_exporter_states.ExportTypes.PROTOBUF_FILE,
{},
)
data_set_dir = os.path.join(test_case_dir, "test_data_set_0")
@ -163,47 +164,12 @@ def export_data(data, value_info_proto, f: str) -> None:
def _export_file(
model_bytes: bytes,
f: io.BytesIO | str,
export_type: str,
export_map: Mapping[str, bytes],
) -> None:
"""export/write model bytes into directory/protobuf/zip"""
if export_type == _exporter_states.ExportTypes.PROTOBUF_FILE:
assert len(export_map) == 0
with torch.serialization._open_file_like(f, "wb") as opened_file:
opened_file.write(model_bytes)
elif export_type in {
_exporter_states.ExportTypes.ZIP_ARCHIVE,
_exporter_states.ExportTypes.COMPRESSED_ZIP_ARCHIVE,
}:
compression = (
zipfile.ZIP_DEFLATED
if export_type == _exporter_states.ExportTypes.COMPRESSED_ZIP_ARCHIVE
else zipfile.ZIP_STORED
)
with zipfile.ZipFile(f, "w", compression=compression) as z:
z.writestr(_constants.ONNX_ARCHIVE_MODEL_PROTO_NAME, model_bytes)
for k, v in export_map.items():
z.writestr(k, v)
elif export_type == _exporter_states.ExportTypes.DIRECTORY:
if isinstance(f, io.BytesIO) or not os.path.isdir(f): # type: ignore[arg-type]
raise ValueError(
f"f should be directory when export_type is set to DIRECTORY, instead get type(f): {type(f)}"
)
if not os.path.exists(f): # type: ignore[arg-type]
os.makedirs(f) # type: ignore[arg-type]
model_proto_file = os.path.join(f, _constants.ONNX_ARCHIVE_MODEL_PROTO_NAME) # type: ignore[arg-type]
with torch.serialization._open_file_like(model_proto_file, "wb") as opened_file:
opened_file.write(model_bytes)
for k, v in export_map.items():
weight_proto_file = os.path.join(f, k) # type: ignore[arg-type]
with torch.serialization._open_file_like(
weight_proto_file, "wb"
) as opened_file:
opened_file.write(v)
else:
raise ValueError("Unknown export type")
assert len(export_map) == 0
with torch.serialization._open_file_like(f, "wb") as opened_file:
opened_file.write(model_bytes)
def _add_onnxscript_fn(

View File

@ -20,13 +20,7 @@ import torch._C._onnx as _C_onnx
import torch.jit._trace
import torch.serialization
from torch import _C
from torch.onnx import ( # noqa: F401
_constants,
_deprecation,
_exporter_states,
errors,
symbolic_helper,
)
from torch.onnx import _constants, _deprecation, errors, symbolic_helper # noqa: F401
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import diagnostics, jit_utils, onnx_proto_utils, registration
@ -1423,9 +1417,6 @@ def _export(
):
assert GLOBALS.in_onnx_export is False
if export_type is None:
export_type = _exporter_states.ExportTypes.PROTOBUF_FILE
if isinstance(model, torch.nn.DataParallel):
raise ValueError(
"torch.nn.DataParallel is not supported by ONNX "
@ -1516,10 +1507,6 @@ def _export(
dynamic_axes=dynamic_axes,
)
# TODO: Don't allocate a in-memory string for the protobuf
defer_weight_export = (
export_type is not _exporter_states.ExportTypes.PROTOBUF_FILE
)
if custom_opsets is None:
custom_opsets = {}
@ -1540,6 +1527,7 @@ def _export(
getattr(model, "training", False), # type: ignore[arg-type]
)
_C._jit_pass_onnx_assign_scoped_names_for_node_and_value(graph)
defer_weight_export = False
if export_params:
(
proto,
@ -1569,7 +1557,7 @@ def _export(
{},
opset_version,
dynamic_axes,
False,
defer_weight_export,
operator_export_type,
not verbose,
val_keep_init_as_ip,
@ -1585,7 +1573,7 @@ def _export(
)
if verbose:
_C._jit_onnx_log("Exported graph: ", graph)
onnx_proto_utils._export_file(proto, f, export_type, export_map)
onnx_proto_utils._export_file(proto, f, export_map)
finally:
assert GLOBALS.in_onnx_export
GLOBALS.in_onnx_export = False

View File

@ -26,7 +26,7 @@ import numpy.typing as npt
import torch
import torch._C._onnx as _C_onnx
from torch import _C
from torch.onnx import _constants, _experimental, _exporter_states, utils
from torch.onnx import _constants, _experimental, utils
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import onnx_proto_utils
from torch.types import Number
@ -893,8 +893,7 @@ def verify_aten_graph(
graph, export_options, onnx_params_dict
)
model_f: str | io.BytesIO = io.BytesIO()
export_type = _exporter_states.ExportTypes.PROTOBUF_FILE
onnx_proto_utils._export_file(proto, model_f, export_type, export_map)
onnx_proto_utils._export_file(proto, model_f, export_map)
# NOTE: Verification is unstable. Try catch to emit information for debugging.
try: