[ONNX] Add Internal Utils: onnx_proto_utils.py for onnx/onnx-script/onnx_proto (#88376)

Added `onnx_proto_utils.py` for onnx/onnx-script-related processing. The idea follows `jit_utils.py`, and the goal is to simplify what we have in `torch/onnx/utils.py`.
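
For context, the exporter call sites after this change reduce to the following minimal sketch (names taken from the diff below; `proto`, `custom_opsets`, `f`, `export_type`, and `export_map` are locals of `_export` in `torch/onnx/utils.py`):

```python
from torch.onnx._internal import onnx_proto_utils

# Insert any onnx-script function protos referenced by the model, then write the
# serialized model to its destination (protobuf file, zip archive, or directory).
proto = onnx_proto_utils._add_onnxscript_fn(proto, custom_opsets)
onnx_proto_utils._export_file(proto, f, export_type, export_map)
```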

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88376
Approved by: https://github.com/justinchuby, https://github.com/BowenBao
AllenTiTaiWang 2022-11-16 19:50:02 +00:00, committed by PyTorch MergeBot
parent f3af5ba48e
commit c3acb9c885
2 changed files with 152 additions and 139 deletions

torch/onnx/_internal/onnx_proto_utils.py (new file)

@@ -0,0 +1,143 @@
"""Utilities for manipulating the onnx and onnx-script dependencies and ONNX proto."""
import io
import os
import zipfile
from typing import List, Mapping, Set, Union
import torch
import torch.jit._trace
import torch.serialization
from torch.onnx import _constants, _exporter_states, errors
from torch.onnx._internal import _beartype, jit_utils, registration


@_beartype.beartype
def _export_file(
    model_bytes: bytes,
    f: Union[io.BytesIO, str],
    export_type: str,
    export_map: Mapping[str, bytes],
) -> None:
    """Export/write model bytes into a directory, protobuf file, or zip archive."""
    # TODO(titaiwang): MYPY asks for os.PathLike[str] type for parameter: f,
    # but beartype raises beartype.roar.BeartypeDecorHintNonpepException,
    # as os.PathLike[str] is uncheckable at runtime.
    if export_type == _exporter_states.ExportTypes.PROTOBUF_FILE:
        assert len(export_map) == 0
        with torch.serialization._open_file_like(f, "wb") as opened_file:
            opened_file.write(model_bytes)
    elif export_type in {
        _exporter_states.ExportTypes.ZIP_ARCHIVE,
        _exporter_states.ExportTypes.COMPRESSED_ZIP_ARCHIVE,
    }:
        compression = (
            zipfile.ZIP_DEFLATED
            if export_type == _exporter_states.ExportTypes.COMPRESSED_ZIP_ARCHIVE
            else zipfile.ZIP_STORED
        )
        with zipfile.ZipFile(f, "w", compression=compression) as z:
            z.writestr(_constants.ONNX_ARCHIVE_MODEL_PROTO_NAME, model_bytes)
            for k, v in export_map.items():
                z.writestr(k, v)
    elif export_type == _exporter_states.ExportTypes.DIRECTORY:
        if isinstance(f, io.BytesIO) or not os.path.isdir(f):  # type: ignore[arg-type]
            raise ValueError(
                f"f should be a directory when export_type is set to DIRECTORY; instead got type(f): {type(f)}"
            )
        if not os.path.exists(f):  # type: ignore[arg-type]
            os.makedirs(f)  # type: ignore[arg-type]
        model_proto_file = os.path.join(f, _constants.ONNX_ARCHIVE_MODEL_PROTO_NAME)  # type: ignore[arg-type]
        with torch.serialization._open_file_like(model_proto_file, "wb") as opened_file:
            opened_file.write(model_bytes)
        for k, v in export_map.items():
            weight_proto_file = os.path.join(f, k)  # type: ignore[arg-type]
            with torch.serialization._open_file_like(
                weight_proto_file, "wb"
            ) as opened_file:
                opened_file.write(v)
    else:
        raise ValueError("Unknown export type")


@_beartype.beartype
def _add_onnxscript_fn(
    model_bytes: bytes,
    custom_opsets: Mapping[str, int],
) -> bytes:
    """Insert custom onnx-script functions used by the model into the ModelProto."""
    # TODO(titaiwang): remove this when onnx becomes dependency
    try:
        import onnx
    except ImportError:
        raise errors.OnnxExporterError("Module onnx is not installed!")
    # For models > 2GB, onnx.load_from_string would fail. However, because
    # _export_onnx saves the tensors separately when the proto size exceeds
    # 2GB, and serialization would already have failed on the protobuf limit
    # if it somehow did not, we don't need to worry about > 2GB models
    # getting here.
    model_proto = onnx.load_from_string(model_bytes)
    # Iterate over graph nodes to insert only the referenced custom
    # function_protos into model_proto.
    # TODO(titaiwang): Currently, onnx-script doesn't support one ONNXFunction
    # calling another ONNXFunction, and neither does this code.
    onnx_function_list = list()  # type: ignore[var-annotated]
    included_node_func = set()  # type: Set[str]
    # onnx_function_list and included_node_func are expanded in-place.
    _find_onnxscript_op(
        model_proto.graph, included_node_func, custom_opsets, onnx_function_list
    )
    if onnx_function_list:
        model_proto.functions.extend(onnx_function_list)
        model_bytes = model_proto.SerializeToString()
    return model_bytes
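
A sketch of how the returned bytes could be inspected with the public `onnx` API (assuming `onnx` is installed; `model_bytes` and the `custom_domain` opset entry are placeholders):

```python
import onnx

from torch.onnx._internal import onnx_proto_utils

new_bytes = onnx_proto_utils._add_onnxscript_fn(model_bytes, {"custom_domain": 1})
model_proto = onnx.load_from_string(new_bytes)
# Function protos collected from the graph now appear in model_proto.functions.
print([f.name for f in model_proto.functions])
```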


@_beartype.beartype
def _find_onnxscript_op(
    graph_proto,
    included_node_func: Set[str],
    custom_opsets: Mapping[str, int],
    onnx_function_list: List,
):
    """Recursively iterate the ModelProto to find ONNXFunction ops, as the graph may contain control-flow ops."""
    for node in graph_proto.node:
        node_kind = node.domain + "::" + node.op_type
        # Recursion is needed for control-flow nodes (If/Loop), which carry an
        # inner graph_proto.
        for attr in node.attribute:
            if attr.g is not None:
                _find_onnxscript_op(
                    attr.g, included_node_func, custom_opsets, onnx_function_list
                )
        # Only custom ops with an ONNX function and aten ops with a symbolic_fn
        # should be found in the registry.
        onnx_function_group = registration.registry.get_function_group(node_kind)
        # Rule out corner cases: onnx/prim ops that are in the registry.
        if (
            node.domain
            and not jit_utils.is_aten(node.domain)
            and not jit_utils.is_prim(node.domain)
            and not jit_utils.is_onnx(node.domain)
            and onnx_function_group is not None
            and node_kind not in included_node_func
        ):
            specified_version = custom_opsets.get(node.domain, 1)
            onnx_fn = onnx_function_group.get(specified_version)
            if onnx_fn is not None:
                # TODO(titaiwang): to_function_proto is an onnx-script API and
                # can be annotated after onnx-script becomes a dependency.
                onnx_function_list.append(onnx_fn.to_function_proto())  # type: ignore[attr-defined]
                included_node_func.add(node_kind)
                continue
            raise errors.UnsupportedOperatorError(
                node_kind,
                specified_version,
                onnx_function_group.get_min_supported()
                if onnx_function_group
                else None,
            )
    return onnx_function_list, included_node_func
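
The recursion pattern above generalizes to any walk over nested graphs. As a standalone sketch using only the public `onnx` API (independent of this PR), counting nodes while descending into If/Loop subgraphs:

```python
import onnx


def count_nodes(graph_proto: onnx.GraphProto) -> int:
    """Count nodes in a graph, descending into control-flow subgraphs."""
    total = 0
    for node in graph_proto.node:
        total += 1
        for attr in node.attribute:
            # Subgraphs are carried in the singular `g` and repeated `graphs`
            # attribute fields of a NodeProto.
            if attr.HasField("g"):
                total += count_nodes(attr.g)
            for subgraph in attr.graphs:
                total += count_nodes(subgraph)
    return total
```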

torch/onnx/utils.py

@@ -9,12 +9,10 @@ import contextlib
import copy
import inspect
import io
import os
import re
import textwrap
import typing
import warnings
import zipfile
from typing import (
    Any,
    Callable,
@@ -38,15 +36,19 @@ import torch.serialization
from torch import _C
from torch.onnx import (  # noqa: F401
    _constants,
    _deprecation,
    _exporter_states,
    _patch_torch,
    errors,
    symbolic_caffe2,
    symbolic_helper,
)
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import _beartype, diagnostics, jit_utils, registration
from torch.onnx._internal import (
    _beartype,
    diagnostics,
    jit_utils,
    onnx_proto_utils,
    registration,
)

__all__ = [
    "is_in_onnx_export",
@@ -1598,13 +1600,13 @@ def _export(
                node_attr_to_name,
            )
            # insert function_proto into model_proto.
            proto = _add_onnxscript_fn(
            proto = onnx_proto_utils._add_onnxscript_fn(
                proto,
                custom_opsets,
            )
            if verbose:
                torch.onnx.log("Exported graph: ", graph)
            _export_file(proto, f, export_type, export_map)
            onnx_proto_utils._export_file(proto, f, export_type, export_map)
            # The ONNX checker only works for ONNX graph. So if the operator_export_type is not ONNX,
            # we can skip this check.
            # If large model format export is enabled, proto will only contain data location instead of
@@ -1625,138 +1627,6 @@ def _export(
        return torch_out
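
For reference, the ONNX checker mentioned in the comment above is the public `onnx.checker` API; a hypothetical standalone invocation (assuming `onnx` is installed and a `model.onnx` file exists):

```python
import onnx

# Validates the model against the ONNX spec; raises
# onnx.checker.ValidationError if the model is malformed.
onnx.checker.check_model(onnx.load("model.onnx"))
```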


@_beartype.beartype
def _export_file(
    model_bytes: bytes,
    f: Union[io.BytesIO, str],
    export_type: str,
    export_map: Mapping[str, bytes],
) -> None:
    """Export/write model bytes into a directory, protobuf file, or zip archive."""
    # TODO(titaiwang): MYPY asks for os.PathLike[str] type for parameter: f,
    # but beartype raises beartype.roar.BeartypeDecorHintNonpepException,
    # as os.PathLike[str] is uncheckable at runtime.
    if export_type == _exporter_states.ExportTypes.PROTOBUF_FILE:
        assert len(export_map) == 0
        with torch.serialization._open_file_like(f, "wb") as opened_file:
            opened_file.write(model_bytes)
    elif export_type in [
        _exporter_states.ExportTypes.ZIP_ARCHIVE,
        _exporter_states.ExportTypes.COMPRESSED_ZIP_ARCHIVE,
    ]:
        compression = (
            zipfile.ZIP_DEFLATED
            if export_type == _exporter_states.ExportTypes.COMPRESSED_ZIP_ARCHIVE
            else zipfile.ZIP_STORED
        )
        with zipfile.ZipFile(f, "w", compression=compression) as z:
            z.writestr(_constants.ONNX_ARCHIVE_MODEL_PROTO_NAME, model_bytes)
            for k, v in export_map.items():
                z.writestr(k, v)
    elif export_type == _exporter_states.ExportTypes.DIRECTORY:
        if isinstance(f, io.BytesIO) or not os.path.isdir(f):  # type: ignore[arg-type]
            raise ValueError(
                f"f should be a directory when export_type is set to DIRECTORY; instead got type(f): {type(f)}"
            )
        if not os.path.exists(f):  # type: ignore[arg-type]
            os.makedirs(f)  # type: ignore[arg-type]
        model_proto_file = os.path.join(f, _constants.ONNX_ARCHIVE_MODEL_PROTO_NAME)  # type: ignore[arg-type]
        with torch.serialization._open_file_like(model_proto_file, "wb") as opened_file:
            opened_file.write(model_bytes)
        for k, v in export_map.items():
            weight_proto_file = os.path.join(f, k)  # type: ignore[arg-type]
            with torch.serialization._open_file_like(
                weight_proto_file, "wb"
            ) as opened_file:
                opened_file.write(v)
    else:
        raise RuntimeError("Unknown export type")


@_beartype.beartype
def _add_onnxscript_fn(
    model_bytes: bytes,
    custom_opsets: Mapping[str, int],
) -> bytes:
    """Insert custom onnx-script functions used by the model into the ModelProto."""
    # TODO(titaiwang): remove this when onnx becomes dependency
    try:
        import onnx
    except ImportError:
        raise errors.OnnxExporterError("Module onnx is not installed!")
    # For models > 2GB, onnx.load_from_string would fail. However, because
    # _export_onnx saves the tensors separately when the proto size exceeds
    # 2GB, and serialization would already have failed on the protobuf limit
    # if it somehow did not, we don't need to worry about > 2GB models
    # getting here.
    model_proto = onnx.load_from_string(model_bytes)
    # Iterate over graph nodes to insert only the referenced custom
    # function_protos into model_proto.
    # TODO(titaiwang): Currently, onnx-script doesn't support one ONNXFunction
    # calling another ONNXFunction, and neither does this code.
    onnx_function_list = list()  # type: ignore[var-annotated]
    included_node_func = set()  # type: Set[str]
    # onnx_function_list and included_node_func are expanded in-place.
    _find_onnxscript_op(
        model_proto.graph, included_node_func, custom_opsets, onnx_function_list
    )
    if onnx_function_list:
        model_proto.functions.extend(onnx_function_list)
        model_bytes = model_proto.SerializeToString()
    return model_bytes


@_beartype.beartype
def _find_onnxscript_op(
    graph_proto,
    included_node_func: Set[str],
    custom_opsets: Mapping[str, int],
    onnx_function_list: List,
):
    """Recursively iterate the ModelProto to find ONNXFunction ops, as the graph may contain control-flow ops."""
    for node in graph_proto.node:
        node_kind = node.domain + "::" + node.op_type
        # Recursion is needed for control-flow nodes (If/Loop), which carry an
        # inner graph_proto.
        for attr in node.attribute:
            if attr.g is not None:
                _find_onnxscript_op(
                    attr.g, included_node_func, custom_opsets, onnx_function_list
                )
        # Only custom ops with an ONNX function and aten ops with a symbolic_fn
        # should be found in the registry.
        onnx_function_group = registration.registry.get_function_group(node_kind)
        # Rule out corner cases: onnx/prim ops that are in the registry.
        if (
            node.domain
            and not jit_utils.is_aten(node.domain)
            and not jit_utils.is_prim(node.domain)
            and not jit_utils.is_onnx(node.domain)
            and onnx_function_group is not None
            and node_kind not in included_node_func
        ):
            specified_version = custom_opsets.get(node.domain, 1)
            onnx_fn = onnx_function_group.get(specified_version)
            if onnx_fn is not None:
                # TODO(titaiwang): to_function_proto is an onnx-script API and
                # can be annotated after onnx-script becomes a dependency.
                onnx_function_list.append(onnx_fn.to_function_proto())  # type: ignore[attr-defined]
                included_node_func.add(node_kind)
                continue
            raise errors.UnsupportedOperatorError(
                node_kind,
                specified_version,
                onnx_function_group.get_min_supported()
                if onnx_function_group
                else None,
            )
    return onnx_function_list, included_node_func


@_beartype.beartype
def _apply_friendly_debug_names(graph, params):
    for n in graph.nodes():