mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103942 Approved by: https://github.com/justinchuby, https://github.com/thiagocrepaldi
158 lines
6.7 KiB
Python
158 lines
6.7 KiB
Python
from __future__ import annotations
|
|
|
|
import os
|
|
from typing import Tuple, TYPE_CHECKING
|
|
|
|
import torch
|
|
from torch.onnx import _type_utils as jit_type_utils
|
|
from torch.onnx._internal import _beartype
|
|
|
|
if TYPE_CHECKING:
|
|
import onnx
|
|
|
|
|
|
@_beartype.beartype
def _create_tensor_proto_with_external_data(
    tensor: torch.Tensor, name: str, location: str, basepath: str
) -> onnx.TensorProto:
    """Create a TensorProto with external data from a PyTorch tensor.

    The external data is saved to os.path.join(basepath, location).

    Args:
        tensor: Tensor to be saved.
        name: Name of the tensor (i.e., initializer name in ONNX graph).
        location: Relative location of the external data file
            (e.g., "/tmp/initializers/weight_0" when model is "/tmp/model_name.onnx").
        basepath: Base path of the external data file (e.g., "/tmp/external_data" while model must be in "/tmp").

    Returns:
        A TensorProto carrying only metadata (name, dtype, dims) plus the
        external-data entries (location/offset/length); the raw bytes live in
        the file written at os.path.join(basepath, location).

    Reference for ONNX's external data format:
        How to load?
        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L187
        How to save?
        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L43
        How to set ONNX fields?
        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L88
    """
    # FIXME: Avoid importing onnx into torch.onnx.
    import onnx

    # Normalize the tensor before serialization:
    # - detach(): .numpy() raises RuntimeError on tensors that require grad.
    # - cpu(): .numpy() only works on CPU tensors.
    # - contiguous(): guarantees the "length" recorded below
    #   (untyped_storage().nbytes()) equals the bytes actually written by
    #   tobytes(). For a non-contiguous view the storage can be larger than
    #   the serialized payload, which would corrupt the external-data entry.
    tensor = tensor.detach().cpu().contiguous()

    tensor_proto = onnx.TensorProto()
    tensor_proto.name = name
    tensor_proto.data_type = jit_type_utils.JitScalarType.from_dtype(
        tensor.dtype
    ).onnx_type()
    tensor_proto.dims.extend(tensor.shape)
    tensor_proto.data_location = onnx.TensorProto.EXTERNAL

    # Settings for saving one tensor per file.
    # Offset is zero because there is no other tensor in the same file.
    key_value_pairs = {
        "location": location,
        "offset": 0,
        "length": tensor.untyped_storage().nbytes(),
    }
    for k, v in key_value_pairs.items():
        entry = tensor_proto.external_data.add()
        entry.key = k
        entry.value = str(v)

    # Actual path to write content of tensor.
    external_data_file_path = os.path.join(basepath, location)
    if os.path.exists(external_data_file_path):
        os.remove(external_data_file_path)

    # Create external data's folder if it does not exist. exist_ok=True avoids
    # the check-then-create race of the naive "if not exists: makedirs" idiom;
    # the truthiness guard skips makedirs("") when location has no directory.
    external_data_dir_path = os.path.dirname(external_data_file_path)
    if external_data_dir_path:
        os.makedirs(external_data_dir_path, exist_ok=True)

    # Create a fresh file; "x" mode fails loudly if the file reappeared after
    # the os.remove above instead of silently clobbering it.
    with open(external_data_file_path, "xb") as data_file:
        # No need to call "seek" because offset is 0.
        # data_file.seek(0)

        # Write tensor content to the file.
        # NOTE(review): .numpy() does not support every torch dtype (e.g.
        # bfloat16) — presumably callers only pass numpy-compatible dtypes;
        # confirm upstream.
        data_file.write(tensor.numpy().tobytes())

    return tensor_proto
|
|
|
|
|
|
@_beartype.beartype
def save_model_with_external_data(
    basepath: str,
    model_location: str,
    initializer_location: str,
    torch_load_paths: Tuple[str, ...],
    onnx_model: onnx.ModelProto,
) -> None:
    """Load PyTorch tensors from files and add to "onnx_model" as external initializers.

    Output files:
        ONNX model file path: os.path.join(basepath, model_location)
        ONNX initializer folder: os.path.join(basepath, initializer_location)

    After running this function, you can do
        ort_sess = onnxruntime.InferenceSession(os.path.join(basepath, model_location))
    to execute the model.

    Arguments:
        basepath: Base path of the external data file (e.g., "/tmp/large-onnx-model").
        model_location: Relative location of the ONNX model file.
            E.g., "model.onnx" so that the model file is saved to
            "/tmp/large-onnx-model/model.onnx".
        initializer_location: Relative location of the ONNX initializer folder.
            E.g., "initializers" so that the initializers are saved to
            "/tmp/large-onnx-model/initializers".
        torch_load_paths: Files which containing serialized PyTorch tensors to be saved
            as ONNX initializers. They are loaded by torch.load.
        onnx_model: ONNX model to be saved with external initializers.
            If an input name matches a tensor loaded from "torch_load_paths",
            the tensor will be saved as that input's external initializer.
    """
    # FIXME: Avoid importing onnx into torch.onnx.
    import onnx

    # Work on a copy so the caller's model object is left untouched.
    onnx_model_with_initializers = onnx.ModelProto()
    onnx_model_with_initializers.CopyFrom(onnx_model)
    onnx_input_names = [input.name for input in onnx_model.graph.input]

    for path in torch_load_paths:
        state_dict = torch.load(path)
        for name, tensor in state_dict.items():
            # Basically, "transformer.attention.self.query.weight" is mapped
            # to "transformer_attention_self_query_weight" for mimicking the
            # name-modifying code in FX-to-ONNX exporter.
            # See function _replace_get_attr_with_placeholder for details.
            refined_name = name.replace(".", "_")

            # For each refined PyTorch tensor name loaded by torch.load,
            #  1.  Search its best match in ONNX model. E.g., the match of
            #      "transformer_attention_weight" could be "attention_weight".
            #  2.  Set "tensor" as the initializer of the matched ONNX input.
            #      E.g., "tensor" is stored as the initializer of "attention_weight".
            # Step 1 is required because sometimes, tensor names are stored with prefix the dictionary
            # loaded by torch.load.
            for onnx_input_name in onnx_input_names:
                # Suffix matching in either direction tolerates prefixes added
                # (or stripped) by either the exporter or the checkpoint.
                if onnx_input_name.endswith(refined_name) or refined_name.endswith(
                    onnx_input_name
                ):
                    # Find a match. Change refined_name to the matched ONNX input name, so that we
                    # create initializer with the right ONNX name.
                    refined_name = onnx_input_name
                    break

            relative_tensor_file_path = os.path.join(initializer_location, refined_name)
            # Create one file per tensor.
            # tensor_proto.raw_data is stored to external file at
            # os.path.join(basepath, relative_tensor_file_path).
            tensor_proto = _create_tensor_proto_with_external_data(
                tensor, refined_name, relative_tensor_file_path, basepath
            )
            # Add the tensor_proto to the ONNX model as an initializer with external data.
            onnx_model_with_initializers.graph.initializer.append(tensor_proto)

    # model_location should be a pure file name such as "file_name.onnx", not "folder/file_name.onnx".
    onnx.save(onnx_model_with_initializers, os.path.join(basepath, model_location))
|