mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ONNX] Remove ExportTypes (#137789)
Remove deprecated ExportTypes and the `_exporter_states` module. Only protobuf (default) is supported going forward. Differential Revision: [D64412947](https://our.internmc.facebook.com/intern/diff/D64412947) Pull Request resolved: https://github.com/pytorch/pytorch/pull/137789 Approved by: https://github.com/titaiwangms, https://github.com/xadupre
This commit is contained in:
parent
af0bc75460
commit
6e38c87ad0
|
|
@ -24,7 +24,6 @@ __all__ = [
|
|||
"symbolic_opset19",
|
||||
"symbolic_opset20",
|
||||
# Enums
|
||||
"ExportTypes",
|
||||
"OperatorExportTypes",
|
||||
"TrainingMode",
|
||||
"TensorProtoDataType",
|
||||
|
|
@ -57,7 +56,6 @@ from torch import _C
|
|||
from torch._C import _onnx as _C_onnx
|
||||
from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode
|
||||
|
||||
from ._exporter_states import ExportTypes
|
||||
from ._internal.exporter._onnx_program import ONNXProgram
|
||||
from ._internal.onnxruntime import (
|
||||
is_onnxrt_backend_supported,
|
||||
|
|
@ -115,7 +113,6 @@ if TYPE_CHECKING:
|
|||
# Set namespace for exposed private names
|
||||
DiagnosticOptions.__module__ = "torch.onnx"
|
||||
ExportOptions.__module__ = "torch.onnx"
|
||||
ExportTypes.__module__ = "torch.onnx"
|
||||
JitScalarType.__module__ = "torch.onnx"
|
||||
ONNXProgram.__module__ = "torch.onnx"
|
||||
ONNXRuntimeOptions.__module__ = "torch.onnx"
|
||||
|
|
|
|||
|
|
@ -4,19 +4,21 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import glob
|
||||
import io
|
||||
import os
|
||||
import shutil
|
||||
import zipfile
|
||||
from typing import Any, Mapping
|
||||
from typing import Any, Mapping, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.jit._trace
|
||||
import torch.serialization
|
||||
from torch.onnx import _constants, _exporter_states, errors
|
||||
from torch.onnx import errors
|
||||
from torch.onnx._internal import jit_utils, registration
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import io
|
||||
|
||||
|
||||
def export_as_test_case(
|
||||
model_bytes: bytes, inputs_data, outputs_data, name: str, dir: str
|
||||
) -> str:
|
||||
|
|
@ -54,7 +56,6 @@ def export_as_test_case(
|
|||
_export_file(
|
||||
model_bytes,
|
||||
os.path.join(test_case_dir, "model.onnx"),
|
||||
_exporter_states.ExportTypes.PROTOBUF_FILE,
|
||||
{},
|
||||
)
|
||||
data_set_dir = os.path.join(test_case_dir, "test_data_set_0")
|
||||
|
|
@ -163,47 +164,12 @@ def export_data(data, value_info_proto, f: str) -> None:
|
|||
def _export_file(
|
||||
model_bytes: bytes,
|
||||
f: io.BytesIO | str,
|
||||
export_type: str,
|
||||
export_map: Mapping[str, bytes],
|
||||
) -> None:
|
||||
"""export/write model bytes into directory/protobuf/zip"""
|
||||
if export_type == _exporter_states.ExportTypes.PROTOBUF_FILE:
|
||||
assert len(export_map) == 0
|
||||
with torch.serialization._open_file_like(f, "wb") as opened_file:
|
||||
opened_file.write(model_bytes)
|
||||
elif export_type in {
|
||||
_exporter_states.ExportTypes.ZIP_ARCHIVE,
|
||||
_exporter_states.ExportTypes.COMPRESSED_ZIP_ARCHIVE,
|
||||
}:
|
||||
compression = (
|
||||
zipfile.ZIP_DEFLATED
|
||||
if export_type == _exporter_states.ExportTypes.COMPRESSED_ZIP_ARCHIVE
|
||||
else zipfile.ZIP_STORED
|
||||
)
|
||||
with zipfile.ZipFile(f, "w", compression=compression) as z:
|
||||
z.writestr(_constants.ONNX_ARCHIVE_MODEL_PROTO_NAME, model_bytes)
|
||||
for k, v in export_map.items():
|
||||
z.writestr(k, v)
|
||||
elif export_type == _exporter_states.ExportTypes.DIRECTORY:
|
||||
if isinstance(f, io.BytesIO) or not os.path.isdir(f): # type: ignore[arg-type]
|
||||
raise ValueError(
|
||||
f"f should be directory when export_type is set to DIRECTORY, instead get type(f): {type(f)}"
|
||||
)
|
||||
if not os.path.exists(f): # type: ignore[arg-type]
|
||||
os.makedirs(f) # type: ignore[arg-type]
|
||||
|
||||
model_proto_file = os.path.join(f, _constants.ONNX_ARCHIVE_MODEL_PROTO_NAME) # type: ignore[arg-type]
|
||||
with torch.serialization._open_file_like(model_proto_file, "wb") as opened_file:
|
||||
opened_file.write(model_bytes)
|
||||
|
||||
for k, v in export_map.items():
|
||||
weight_proto_file = os.path.join(f, k) # type: ignore[arg-type]
|
||||
with torch.serialization._open_file_like(
|
||||
weight_proto_file, "wb"
|
||||
) as opened_file:
|
||||
opened_file.write(v)
|
||||
else:
|
||||
raise ValueError("Unknown export type")
|
||||
assert len(export_map) == 0
|
||||
with torch.serialization._open_file_like(f, "wb") as opened_file:
|
||||
opened_file.write(model_bytes)
|
||||
|
||||
|
||||
def _add_onnxscript_fn(
|
||||
|
|
|
|||
|
|
@ -20,13 +20,7 @@ import torch._C._onnx as _C_onnx
|
|||
import torch.jit._trace
|
||||
import torch.serialization
|
||||
from torch import _C
|
||||
from torch.onnx import ( # noqa: F401
|
||||
_constants,
|
||||
_deprecation,
|
||||
_exporter_states,
|
||||
errors,
|
||||
symbolic_helper,
|
||||
)
|
||||
from torch.onnx import _constants, _deprecation, errors, symbolic_helper # noqa: F401
|
||||
from torch.onnx._globals import GLOBALS
|
||||
from torch.onnx._internal import diagnostics, jit_utils, onnx_proto_utils, registration
|
||||
|
||||
|
|
@ -1423,9 +1417,6 @@ def _export(
|
|||
):
|
||||
assert GLOBALS.in_onnx_export is False
|
||||
|
||||
if export_type is None:
|
||||
export_type = _exporter_states.ExportTypes.PROTOBUF_FILE
|
||||
|
||||
if isinstance(model, torch.nn.DataParallel):
|
||||
raise ValueError(
|
||||
"torch.nn.DataParallel is not supported by ONNX "
|
||||
|
|
@ -1516,10 +1507,6 @@ def _export(
|
|||
dynamic_axes=dynamic_axes,
|
||||
)
|
||||
|
||||
# TODO: Don't allocate a in-memory string for the protobuf
|
||||
defer_weight_export = (
|
||||
export_type is not _exporter_states.ExportTypes.PROTOBUF_FILE
|
||||
)
|
||||
if custom_opsets is None:
|
||||
custom_opsets = {}
|
||||
|
||||
|
|
@ -1540,6 +1527,7 @@ def _export(
|
|||
getattr(model, "training", False), # type: ignore[arg-type]
|
||||
)
|
||||
_C._jit_pass_onnx_assign_scoped_names_for_node_and_value(graph)
|
||||
defer_weight_export = False
|
||||
if export_params:
|
||||
(
|
||||
proto,
|
||||
|
|
@ -1569,7 +1557,7 @@ def _export(
|
|||
{},
|
||||
opset_version,
|
||||
dynamic_axes,
|
||||
False,
|
||||
defer_weight_export,
|
||||
operator_export_type,
|
||||
not verbose,
|
||||
val_keep_init_as_ip,
|
||||
|
|
@ -1585,7 +1573,7 @@ def _export(
|
|||
)
|
||||
if verbose:
|
||||
_C._jit_onnx_log("Exported graph: ", graph)
|
||||
onnx_proto_utils._export_file(proto, f, export_type, export_map)
|
||||
onnx_proto_utils._export_file(proto, f, export_map)
|
||||
finally:
|
||||
assert GLOBALS.in_onnx_export
|
||||
GLOBALS.in_onnx_export = False
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ import numpy.typing as npt
|
|||
import torch
|
||||
import torch._C._onnx as _C_onnx
|
||||
from torch import _C
|
||||
from torch.onnx import _constants, _experimental, _exporter_states, utils
|
||||
from torch.onnx import _constants, _experimental, utils
|
||||
from torch.onnx._globals import GLOBALS
|
||||
from torch.onnx._internal import onnx_proto_utils
|
||||
from torch.types import Number
|
||||
|
|
@ -893,8 +893,7 @@ def verify_aten_graph(
|
|||
graph, export_options, onnx_params_dict
|
||||
)
|
||||
model_f: str | io.BytesIO = io.BytesIO()
|
||||
export_type = _exporter_states.ExportTypes.PROTOBUF_FILE
|
||||
onnx_proto_utils._export_file(proto, model_f, export_type, export_map)
|
||||
onnx_proto_utils._export_file(proto, model_f, export_map)
|
||||
|
||||
# NOTE: Verification is unstable. Try catch to emit information for debugging.
|
||||
try:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user