from __future__ import annotations import os from typing import Tuple import onnx import torch from torch.onnx._internal import _beartype @_beartype.beartype def _create_tensor_proto_with_external_data( tensor: torch.Tensor, name: str, location: str, basepath: str ) -> "onnx.TensorProto": """Create a TensorProto with external data from a PyTorch tensor. The external data is saved to os.path.join(basepath, location). Args: tensor: Tensor to be saved. name: Name of the tensor (i.e., initializer name in ONNX graph). location: Relative location of the external data file (e.g., "/tmp/initializers/weight_0" when model is "/tmp/model_name.onnx"). basepath: Base path of the external data file (e.g., "/tmp/external_data" while model must be in "/tmp"). Reference for ONNX's external data format: How to load? https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L187 How to save? https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L43 How to set ONNX fields? https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L88 """ tensor_proto = onnx.TensorProto() tensor_proto.name = name tensor_proto.data_type = torch.onnx._type_utils._SCALAR_TYPE_TO_ONNX[ # type: ignore[assignment] torch.onnx._type_utils._DTYPE_TO_SCALAR_TYPE[tensor.dtype] ] tensor_proto.dims.extend(tensor.shape) tensor_proto.data_location = onnx.TensorProto.EXTERNAL # Settings for saving one tensor per file. # Offset is zero because there is no other tensor in the same file. key_value_pairs = { "location": location, "offset": 0, "length": tensor.untyped_storage().nbytes(), } for k, v in key_value_pairs.items(): entry = tensor_proto.external_data.add() entry.key = k entry.value = str(v) # Actual path to write content of tensor. external_data_file_path = os.path.join(basepath, location) if os.path.exists(external_data_file_path): os.remove(external_data_file_path) # Create external data's folder if not exists. external_data_dir_path = os.path.dirname(external_data_file_path) if not os.path.exists(external_data_dir_path): # if the demo_folder directory is not present # then create it. os.makedirs(external_data_dir_path) # Create a fresh file. with open(external_data_file_path, "xb") as data_file: # No need to call "seek" because offset is 0. # data_file.seek(0) # Write tensor content to the file. data_file.write(tensor.numpy().tobytes()) return tensor_proto @_beartype.beartype def save_model_with_external_data( basepath: str, model_location: str, initializer_location: str, torch_load_paths: Tuple[str, ...], onnx_model: onnx.ModelProto, ) -> None: """Load PyTorch tensors from files and add to "onnx_model" as external initializers. Output files: ONNX model file path: ONNX initializer folder: os.path.join(basepath, initializer_location) After running this function, you can do ort_sess = onnxruntime.InferenceSession(os.path.join(basepath, model_location)) to execute the model. Arguments: basepath: Base path of the external data file (e.g., "/tmp/large-onnx-model"). model_location: Relative location of the ONNX model file. E.g., "model.onnx" so that the model file is saved to "/tmp/large-onnx-model/model.onnx". initializer_location: Relative location of the ONNX initializer folder. E.g., "initializers" so that the initializers are saved to "/tmp/large-onnx-model/initializers". torch_load_paths: Files which containing serialized PyTorch tensors to be saved as ONNX initializers. They are loaded by torch.load. onnx_model: ONNX model to be saved with external initializers. If an input name matches a tensor loaded from "torch_load_paths", the tensor will be saved as that input's external initializer. """ onnx_model_with_initializers = onnx.ModelProto() onnx_model_with_initializers.CopyFrom(onnx_model) onnx_input_names = [input.name for input in onnx_model.graph.input] for path in torch_load_paths: state_ditc = torch.load(path) for name, tensor in state_ditc.items(): # Basically, "transformer.attention.self.query.weight" is mapped # to "transformer_attention_self_query_weight" for mimicking the # name-modifying code in FX-to-ONNX exporter. # See function _replace_get_attr_with_placeholder for details. refined_name = name.replace(".", "_") # For each refined PyTorch tensor name loaded by torch.load, # 1. Search its best match in ONNX model. E.g., the match of # "transformer_attention_weight" could be "attention_weight". # 2. Set "tensor" as the initializer of the matched ONNX input. # E.g., "tensor" is stored as the initializer of "attention_weight". # Step 1 is required because sometimes, tensor names are stored with prefix the dictionary # loaded by torch.load. for onnx_input_name in onnx_input_names: if onnx_input_name.endswith(refined_name) or refined_name.endswith( onnx_input_name ): # Find a match. Change refined_name to the matched ONNX input name, so that we # create initializer with the right ONNX name. refined_name = onnx_input_name break relative_tensor_file_path = os.path.join(initializer_location, refined_name) # Create one file per tensor. # tensor_proto.raw_data is stored to external file at # os.path.join(basepath, relative_tensor_file_path). tensor_proto = _create_tensor_proto_with_external_data( tensor, refined_name, relative_tensor_file_path, basepath ) # Add the tensor_proto to the ONNX model as an initializer with external data. onnx_model_with_initializers.graph.initializer.append(tensor_proto) # model_location should be a pure file name such as "file_name.onnx", not "folder/file_name.onnx". onnx.save(onnx_model_with_initializers, os.path.join(basepath, model_location))