[ONNX] Add Internal Utils: onnx_proto_utils.py for onnx/onnx-script/onnx_proto (#88376)

Added `onnx_proto_utils.py` for onnx/onnx-script-related processing, following the pattern of `jit_utils.py`, to simplify what we have in `torch/onnx/utils.py`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88376
Approved by: https://github.com/justinchuby, https://github.com/BowenBao
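In effect, the `_export` call sites in `torch/onnx/utils.py` shrink to a qualified form, as condensed below from the diff that follows (no behavior change intended):

# before: helpers defined locally in torch/onnx/utils.py
proto = _add_onnxscript_fn(proto, custom_opsets)
_export_file(proto, f, export_type, export_map)

# after: the same helpers imported from the new internal module
from torch.onnx._internal import onnx_proto_utils

proto = onnx_proto_utils._add_onnxscript_fn(proto, custom_opsets)
onnx_proto_utils._export_file(proto, f, export_type, export_map)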
This commit is contained in:
parent f3af5ba48e
commit c3acb9c885

torch/onnx/_internal/onnx_proto_utils.py  (new file, 143 lines added)

@@ -0,0 +1,143 @@
"""Utilities for manipulating the onnx and onnx-script dependencies and ONNX proto."""
|
||||
|
||||
import io
|
||||
import os
|
||||
import zipfile
|
||||
from typing import List, Mapping, Set, Union
|
||||
|
||||
import torch
|
||||
import torch.jit._trace
|
||||
import torch.serialization
|
||||
from torch.onnx import _constants, _exporter_states, errors
|
||||
from torch.onnx._internal import _beartype, jit_utils, registration
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _export_file(
|
||||
model_bytes: bytes,
|
||||
f: Union[io.BytesIO, str],
|
||||
export_type: str,
|
||||
export_map: Mapping[str, bytes],
|
||||
) -> None:
|
||||
"""export/write model bytes into directory/protobuf/zip"""
|
||||
# TODO(titaiwang) MYPY asks for os.PathLike[str] type for parameter: f,
|
||||
# but beartype raises beartype.roar.BeartypeDecorHintNonpepException,
|
||||
# as os.PathLike[str] uncheckable at runtime
|
||||
if export_type == _exporter_states.ExportTypes.PROTOBUF_FILE:
|
||||
assert len(export_map) == 0
|
||||
with torch.serialization._open_file_like(f, "wb") as opened_file:
|
||||
opened_file.write(model_bytes)
|
||||
elif export_type in {
|
||||
_exporter_states.ExportTypes.ZIP_ARCHIVE,
|
||||
_exporter_states.ExportTypes.COMPRESSED_ZIP_ARCHIVE,
|
||||
}:
|
||||
compression = (
|
||||
zipfile.ZIP_DEFLATED
|
||||
if export_type == _exporter_states.ExportTypes.COMPRESSED_ZIP_ARCHIVE
|
||||
else zipfile.ZIP_STORED
|
||||
)
|
||||
with zipfile.ZipFile(f, "w", compression=compression) as z:
|
||||
z.writestr(_constants.ONNX_ARCHIVE_MODEL_PROTO_NAME, model_bytes)
|
||||
for k, v in export_map.items():
|
||||
z.writestr(k, v)
|
||||
elif export_type == _exporter_states.ExportTypes.DIRECTORY:
|
||||
if isinstance(f, io.BytesIO) or not os.path.isdir(f): # type: ignore[arg-type]
|
||||
raise ValueError(
|
||||
f"f should be directory when export_type is set to DIRECTORY, instead get type(f): {type(f)}"
|
||||
)
|
||||
if not os.path.exists(f): # type: ignore[arg-type]
|
||||
os.makedirs(f) # type: ignore[arg-type]
|
||||
|
||||
model_proto_file = os.path.join(f, _constants.ONNX_ARCHIVE_MODEL_PROTO_NAME) # type: ignore[arg-type]
|
||||
with torch.serialization._open_file_like(model_proto_file, "wb") as opened_file:
|
||||
opened_file.write(model_bytes)
|
||||
|
||||
for k, v in export_map.items():
|
||||
weight_proto_file = os.path.join(f, k) # type: ignore[arg-type]
|
||||
with torch.serialization._open_file_like(
|
||||
weight_proto_file, "wb"
|
||||
) as opened_file:
|
||||
opened_file.write(v)
|
||||
else:
|
||||
raise ValueError("Unknown export type")
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _add_onnxscript_fn(
|
||||
model_bytes: bytes,
|
||||
custom_opsets: Mapping[str, int],
|
||||
) -> bytes:
|
||||
"""Insert model-included custom onnx-script function into ModelProto"""
|
||||
# TODO(titaiwang): remove this when onnx becomes dependency
|
||||
try:
|
||||
import onnx
|
||||
except ImportError:
|
||||
raise errors.OnnxExporterError("Module onnx is not installed!")
|
||||
|
||||
# For > 2GB model, onnx.load_fromstring would fail. However, because
|
||||
# in _export_onnx, the tensors should be saved separately if the proto
|
||||
# size > 2GB, and if it for some reason did not, the model would fail on
|
||||
# serialization anyway in terms of the protobuf limitation. So we don't
|
||||
# need to worry about > 2GB model getting here.
|
||||
model_proto = onnx.load_from_string(model_bytes)
|
||||
|
||||
# Iterate graph nodes to insert only the included custom
|
||||
# function_proto into model_proto
|
||||
# TODO(titaiwang): Currently, onnxscript doesn't support ONNXFunction
|
||||
# calling other ONNXFunction scenario, neither does it here
|
||||
onnx_function_list = list() # type: ignore[var-annotated]
|
||||
included_node_func = set() # type: Set[str]
|
||||
# onnx_function_list and included_node_func are expanded in-place
|
||||
_find_onnxscript_op(
|
||||
model_proto.graph, included_node_func, custom_opsets, onnx_function_list
|
||||
)
|
||||
|
||||
if onnx_function_list:
|
||||
model_proto.functions.extend(onnx_function_list)
|
||||
model_bytes = model_proto.SerializeToString()
|
||||
return model_bytes
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _find_onnxscript_op(
|
||||
graph_proto,
|
||||
included_node_func: Set[str],
|
||||
custom_opsets: Mapping[str, int],
|
||||
onnx_function_list: List,
|
||||
):
|
||||
"""Recursively iterate ModelProto to find ONNXFunction op as it may contain control flow Op."""
|
||||
for node in graph_proto.node:
|
||||
node_kind = node.domain + "::" + node.op_type
|
||||
# Recursive needed for control flow nodes: IF/Loop which has inner graph_proto
|
||||
for attr in node.attribute:
|
||||
if attr.g is not None:
|
||||
_find_onnxscript_op(
|
||||
attr.g, included_node_func, custom_opsets, onnx_function_list
|
||||
)
|
||||
# Only custom Op with ONNX function and aten with symbolic_fn should be found in registry
|
||||
onnx_function_group = registration.registry.get_function_group(node_kind)
|
||||
# Ruled out corner cases: onnx/prim in registry
|
||||
if (
|
||||
node.domain
|
||||
and not jit_utils.is_aten(node.domain)
|
||||
and not jit_utils.is_prim(node.domain)
|
||||
and not jit_utils.is_onnx(node.domain)
|
||||
and onnx_function_group is not None
|
||||
and node_kind not in included_node_func
|
||||
):
|
||||
specified_version = custom_opsets.get(node.domain, 1)
|
||||
onnx_fn = onnx_function_group.get(specified_version)
|
||||
if onnx_fn is not None:
|
||||
# TODO(titaiwang): to_function_proto is onnx-script API and can be annotated
|
||||
# after onnx-script is dependency
|
||||
onnx_function_list.append(onnx_fn.to_function_proto()) # type: ignore[attr-defined]
|
||||
included_node_func.add(node_kind)
|
||||
continue
|
||||
raise errors.UnsupportedOperatorError(
|
||||
node_kind,
|
||||
specified_version,
|
||||
onnx_function_group.get_min_supported()
|
||||
if onnx_function_group
|
||||
else None,
|
||||
)
|
||||
return onnx_function_list, included_node_func
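For orientation, a minimal usage sketch of the two module-level helpers. These are private `_internal` functions, not a supported API, so the pattern below is illustrative only; it assumes the `onnx` package is installed and uses a toy ReLU model plus a made-up output path:

import io

import torch
from torch.onnx import _exporter_states
from torch.onnx._internal import onnx_proto_utils

# Produce real model bytes with the public exporter first.
buffer = io.BytesIO()
torch.onnx.export(torch.nn.ReLU(), (torch.randn(2, 3),), buffer)
model_bytes = buffer.getvalue()

# Attach any onnx-script FunctionProtos referenced by the graph (a no-op for
# plain ReLU, which needs no custom functions), then write a single .onnx
# protobuf file; export_map must be empty for PROTOBUF_FILE.
model_bytes = onnx_proto_utils._add_onnxscript_fn(model_bytes, {})
onnx_proto_utils._export_file(
    model_bytes, "model.onnx", _exporter_states.ExportTypes.PROTOBUF_FILE, {}
)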
torch/onnx/utils.py

@@ -9,12 +9,10 @@ import contextlib
 import copy
 import inspect
 import io
-import os
 import re
 import textwrap
 import typing
 import warnings
-import zipfile
 from typing import (
     Any,
     Callable,
@@ -38,15 +36,19 @@ import torch.serialization
 from torch import _C
 from torch.onnx import (  # noqa: F401
     _constants,
     _deprecation,
     _exporter_states,
     _patch_torch,
     errors,
     symbolic_caffe2,
     symbolic_helper,
 )
 from torch.onnx._globals import GLOBALS
-from torch.onnx._internal import _beartype, diagnostics, jit_utils, registration
+from torch.onnx._internal import (
+    _beartype,
+    diagnostics,
+    jit_utils,
+    onnx_proto_utils,
+    registration,
+)
 
 __all__ = [
     "is_in_onnx_export",
@@ -1598,13 +1600,13 @@ def _export(
                 node_attr_to_name,
             )
             # insert function_proto into model_proto.
-            proto = _add_onnxscript_fn(
+            proto = onnx_proto_utils._add_onnxscript_fn(
                 proto,
                 custom_opsets,
             )
             if verbose:
                 torch.onnx.log("Exported graph: ", graph)
-            _export_file(proto, f, export_type, export_map)
+            onnx_proto_utils._export_file(proto, f, export_type, export_map)
             # The ONNX checker only works for ONNX graph. So if the operator_export_type is not ONNX,
             # we can skip this check.
             # If large model format export is enabled, proto will only contain data location instead of
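For reference, the `custom_opsets` mapping threaded into `_add_onnxscript_fn` above comes from the public `torch.onnx.export` argument of the same name, which maps a custom operator domain to the opset version to select. A hypothetical call (the domain name is made up for illustration):

import torch

torch.onnx.export(
    torch.nn.ReLU(),
    (torch.randn(2, 3),),
    "model.onnx",
    custom_opsets={"com.example": 2},  # hypothetical custom domain -> opset version
)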
@@ -1625,138 +1627,6 @@ def _export(
     return torch_out
 
 
-@_beartype.beartype
-def _export_file(
-    model_bytes: bytes,
-    f: Union[io.BytesIO, str],
-    export_type: str,
-    export_map: Mapping[str, bytes],
-) -> None:
-    """export/write model bytes into directory/protobuf/zip"""
-    # TODO(titaiwang) MYPY asks for os.PathLike[str] type for parameter: f,
-    # but beartype raises beartype.roar.BeartypeDecorHintNonpepException,
-    # as os.PathLike[str] uncheckable at runtime
-    if export_type == _exporter_states.ExportTypes.PROTOBUF_FILE:
-        assert len(export_map) == 0
-        with torch.serialization._open_file_like(f, "wb") as opened_file:
-            opened_file.write(model_bytes)
-    elif export_type in [
-        _exporter_states.ExportTypes.ZIP_ARCHIVE,
-        _exporter_states.ExportTypes.COMPRESSED_ZIP_ARCHIVE,
-    ]:
-        compression = (
-            zipfile.ZIP_DEFLATED
-            if export_type == _exporter_states.ExportTypes.COMPRESSED_ZIP_ARCHIVE
-            else zipfile.ZIP_STORED
-        )
-        with zipfile.ZipFile(f, "w", compression=compression) as z:
-            z.writestr(_constants.ONNX_ARCHIVE_MODEL_PROTO_NAME, model_bytes)
-            for k, v in export_map.items():
-                z.writestr(k, v)
-    elif export_type == _exporter_states.ExportTypes.DIRECTORY:
-        if isinstance(f, io.BytesIO) or not os.path.isdir(f):  # type: ignore[arg-type]
-            raise ValueError(
-                f"f should be directory when export_type is set to DIRECTORY, instead get type(f): {type(f)}"
-            )
-        if not os.path.exists(f):  # type: ignore[arg-type]
-            os.makedirs(f)  # type: ignore[arg-type]
-
-        model_proto_file = os.path.join(f, _constants.ONNX_ARCHIVE_MODEL_PROTO_NAME)  # type: ignore[arg-type]
-        with torch.serialization._open_file_like(model_proto_file, "wb") as opened_file:
-            opened_file.write(model_bytes)
-
-        for k, v in export_map.items():
-            weight_proto_file = os.path.join(f, k)  # type: ignore[arg-type]
-            with torch.serialization._open_file_like(
-                weight_proto_file, "wb"
-            ) as opened_file:
-                opened_file.write(v)
-    else:
-        raise RuntimeError("Unknown export type")
-
-
-@_beartype.beartype
-def _add_onnxscript_fn(
-    model_bytes: bytes,
-    custom_opsets: Mapping[str, int],
-) -> bytes:
-    """Insert model-included custom onnx-script function into ModelProto"""
-
-    # TODO(titaiwang): remove this when onnx becomes dependency
-    try:
-        import onnx
-    except ImportError:
-        raise errors.OnnxExporterError("Module onnx is not installed!")
-
-    # For > 2GB model, onnx.load_from_string would fail. However, because
-    # in _export_onnx, the tensors should be saved separately if the proto
-    # size > 2GB, and if it for some reason did not, the model would fail on
-    # serialization anyway in terms of the protobuf limitation. So we don't
-    # need to worry about > 2GB model getting here.
-    model_proto = onnx.load_from_string(model_bytes)
-
-    # Iterate graph nodes to insert only the included custom
-    # function_proto into model_proto
-    # TODO(titaiwang): Currently, onnxscript doesn't support ONNXFunction
-    # calling other ONNXFunction scenario, neither does it here
-    onnx_function_list = list()  # type: ignore[var-annotated]
-    included_node_func = set()  # type: Set[str]
-    # onnx_function_list and included_node_func are expanded in-place
-    _find_onnxscript_op(
-        model_proto.graph, included_node_func, custom_opsets, onnx_function_list
-    )
-
-    if onnx_function_list:
-        model_proto.functions.extend(onnx_function_list)
-        model_bytes = model_proto.SerializeToString()
-    return model_bytes
-
-
-@_beartype.beartype
-def _find_onnxscript_op(
-    graph_proto,
-    included_node_func: Set[str],
-    custom_opsets: Mapping[str, int],
-    onnx_function_list: List,
-):
-    """Recursively iterate ModelProto to find ONNXFunction op as it may contain control flow Op."""
-    for node in graph_proto.node:
-        node_kind = node.domain + "::" + node.op_type
-        # Recursive is needed for control flow nodes: IF/Loop which has inner graph_proto
-        for attr in node.attribute:
-            if attr.g is not None:
-                _find_onnxscript_op(
-                    attr.g, included_node_func, custom_opsets, onnx_function_list
-                )
-        # Only custom Op with ONNX function and aten with symbolic_fn should be found in registry
-        onnx_function_group = registration.registry.get_function_group(node_kind)
-        # Ruled out corner cases: onnx/prim in registry
-        if (
-            node.domain
-            and not jit_utils.is_aten(node.domain)
-            and not jit_utils.is_prim(node.domain)
-            and not jit_utils.is_onnx(node.domain)
-            and onnx_function_group is not None
-            and node_kind not in included_node_func
-        ):
-            specified_version = custom_opsets.get(node.domain, 1)
-            onnx_fn = onnx_function_group.get(specified_version)
-            if onnx_fn is not None:
-                # TODO(titaiwang): to_function_proto is onnx-script API and can be annotated
-                # after onnx-script is dependency
-                onnx_function_list.append(onnx_fn.to_function_proto())  # type: ignore[attr-defined]
-                included_node_func.add(node_kind)
-                continue
-            raise errors.UnsupportedOperatorError(
-                node_kind,
-                specified_version,
-                onnx_function_group.get_min_supported()
-                if onnx_function_group
-                else None,
-            )
-    return onnx_function_list, included_node_func
 
 
 @_beartype.beartype
 def _apply_friendly_debug_names(graph, params):
     for n in graph.nodes():