pytorch/torch/onnx/_internal/fx/serialization.py
Justin Chu e880cb2fe0 [ONNX] Remove beartype usage (#130484)
beartype has served us well in identifying type errors and ensuring we call internal functions with the correct arguments (thanks!). However, the value of having beartype is diminished because of the following:

1. When beartype improves support for better Dict[] type checking, it discovered typing mistakes in some functions that were previously uncaught. This caused the exporter to fail with newer versions beartype when it used to succeed. Since we cannot fix PyTorch and release a new version just because of this, it creates confusion for users that have beartype in their environment from using torch.onnx
2. beartype adds an additional call line in the traceback, which makes the already thick dynamo stack even larger, affecting readability when users diagnose errors with the traceback.
3. Since the typing annotations need to be evaluated, we cannot use new syntaxes like `|` because we need to maintain compatibility with Python 3.8. We don't want to wait for PyTorch take py310 as the lowest supported Python before using the new typing syntaxes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130484
Approved by: https://github.com/titaiwangms
2024-07-18 22:07:40 +00:00

245 lines
11 KiB
Python

# mypy: allow-untyped-defs
from __future__ import annotations
import io
import logging
import os
from typing import Optional, Tuple, TYPE_CHECKING, Union
import torch
from torch.onnx import _type_utils as jit_type_utils
if TYPE_CHECKING:
import onnx
log = logging.getLogger(__name__)
def _create_tensor_proto_with_external_data(
tensor: torch.Tensor,
name: str,
location: str,
basepath: str,
dtype_override: Optional[onnx.TypeProto] = None, # type: ignore[name-defined]
) -> onnx.TensorProto: # type: ignore[name-defined]
"""Create a TensorProto with external data from a PyTorch tensor.
The external data is saved to os.path.join(basepath, location).
Args:
tensor: Tensor to be saved.
name: Name of the tensor (i.e., initializer name in ONNX graph).
location: Relative location of the external data file
(e.g., "/tmp/initializers/weight_0" when model is "/tmp/model_name.onnx").
basepath: Base path of the external data file (e.g., "/tmp/external_data" while model must be in "/tmp").
Reference for ONNX's external data format:
How to load?
https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L187
How to save?
https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L43
How to set ONNX fields?
https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L88
"""
# FIXME: Avoid importing onnx into torch.onnx.
import onnx
scalar_type = (
jit_type_utils.JitScalarType.from_onnx_type(
dtype_override.tensor_type.elem_type
)
if dtype_override is not None
else jit_type_utils.JitScalarType.from_dtype(tensor.dtype)
)
# Checkpoints can be stored with a different dtype as the model expects because
# the user script can explicitly cast the original type to something or maybe
# PyTorch's type promotion might do it
if dtype_override is not None and scalar_type.dtype() != tensor.dtype:
tensor = tensor.to(scalar_type.dtype())
tensor_proto = onnx.TensorProto() # type: ignore[attr-defined]
tensor_proto.name = name
tensor_proto.data_type = scalar_type.onnx_type()
tensor_proto.dims.extend(tensor.shape)
tensor_proto.data_location = onnx.TensorProto.EXTERNAL # type: ignore[attr-defined]
# Settings for saving one tensor per file.
# Offset is zero because there is no other tensor in the same file.
key_value_pairs = {
"location": location,
"offset": 0,
"length": tensor.untyped_storage().nbytes(),
}
for k, v in key_value_pairs.items():
entry = tensor_proto.external_data.add()
entry.key = k
entry.value = str(v)
# Actual path to write content of tensor.
external_data_file_path = os.path.join(basepath, location)
if os.path.exists(external_data_file_path):
os.remove(external_data_file_path)
# Create external data's folder if not exists.
external_data_dir_path = os.path.dirname(external_data_file_path)
if not os.path.exists(external_data_dir_path):
# if the demo_folder directory is not present
# then create it.
os.makedirs(external_data_dir_path)
# Create a fresh file.
with open(external_data_file_path, "xb") as data_file:
# No need to call "seek" because offset is 0.
# data_file.seek(0)
# Write tensor content to the file.
data_file.write(tensor.numpy(force=True).tobytes())
return tensor_proto
def _convert_safetensors_to_torch_format(safetensors_file):
# It this function is called, safetensors is guaranteed to exist
# because the HF model with safetensors was already loaded and exported to ONNX
from safetensors import safe_open # type: ignore[import-not-found]
tensors = {}
with safe_open(safetensors_file, framework="pt", device="cpu") as f: # type: ignore[attr-defined]
for k in f.keys():
tensors[k] = f.get_tensor(k).cpu()
return tensors
# TODO: generalize to allow more checkpoints formats (torch or gguf)
def save_model_with_external_data(
basepath: str,
model_location: str,
initializer_location: str,
torch_state_dicts: Tuple[Union[dict, str, io.BytesIO], ...],
onnx_model: onnx.ModelProto, # type: ignore[name-defined]
rename_initializer: bool = False,
) -> None:
"""Load PyTorch tensors from files and add to "onnx_model" as external initializers.
Output files:
ONNX model file path:
ONNX initializer folder: os.path.join(basepath, initializer_location)
After running this function, you can do
ort_sess = onnxruntime.InferenceSession(os.path.join(basepath, model_location))
to execute the model.
Arguments:
basepath: Base path of the ONNX external data file (e.g., "/path/to/large_model/").
model_location: Relative location of the ONNX model file.
E.g., "model.onnx" so that the model file is saved to
"<basepath>/model.onnx".
initializer_location: Relative location of the ONNX initializer folder.
E.g., "initializers" so that the initializers are saved to
"<basepath>/initializers/".
Note: When initializers are >2GB, must be the same as `model_location`.
torch_state_dicts: Dictionaries or files which contain PyTorch tensors to be saved
as ONNX initializers. For non-dict arguments, `torch.load` will be used to load them from file-like objects.
onnx_model: ONNX model to be saved with external initializers.
If an input name matches a tensor loaded from "torch_state_dicts",
the tensor will be saved as that input's external initializer.
rename_initializer: Replaces "." by "_" for all ONNX initializer names.
Not needed by the official torch.onnx.dynamo_export. This is a hack
for supporting `FXSymbolicTracer` tracer with fake tensor mode.
In short, `FXSymbolicTracer` lifts FX parameters (self.linear_weight)
as inputs (`def forward(self, linear_weight)`) and therefore, `.` cannot be used.
"""
# FIXME: Avoid importing onnx into torch.onnx.
import onnx
initializers_to_be_deleted = {} # Using dict because it is **ordered**
existing_initializers = {
k.name: idx for idx, k in enumerate(onnx_model.graph.initializer)
}
onnx_input_names = {input.name for input in onnx_model.graph.input}
for el in torch_state_dicts:
if isinstance(el, dict):
# Useful for when state_dict is loaded with torch.load(..., mmap=True, map_location="cpu") by the user
# Using torch.save wouldn't leverage mmap, leading to higher memory usage
state_dict = el
else:
if isinstance(el, str) and el.endswith(".safetensors"):
state_dict = _convert_safetensors_to_torch_format(el)
else:
try:
# Loads checkpoint using memory-map on CPU to support really large models
# The underlying torch.UntypedStorage is memory mapped, so state_dict is lazy loaded
state_dict = torch.load(el, map_location="cpu", mmap=True)
except (RuntimeError, ValueError) as e:
if "mmap can only be used with files saved with" in str(
e
) or isinstance(el, io.BytesIO):
log.warning(
"Failed to load the checkpoint with memory-map enabled, retrying without memory-map."
"Consider updating the checkpoint with mmap by using torch.save() on PyTorch version >= 1.6."
)
if isinstance(el, io.BytesIO):
el.seek(0) # torch.load from `try:` has read the file.
state_dict = torch.load(el, map_location="cpu")
else:
raise e
for name, tensor in state_dict.items():
if rename_initializer:
# Basically, "transformer.attention.self.query.weight" is mapped
# to "transformer_attention_self_query_weight" for mimicking the
# name-modifying code in FX-to-ONNX exporter.
# See function _replace_get_attr_with_placeholder for details.
name = name.replace(".", "_")
# This block tries to match the onnx initializer name with torch parameter/buffer
# e.g. A pytorch buffer 'transformer.h.0.attn.bias' can be named 'h.0.attn.bias' in a ONNX initializer
# For each PyTorch tensor name loaded by torch.load,
# 1. Search its best match in ONNX model. E.g., the match of
# "transformer_attention_weight" could be "attention_weight".
# 2. Set "tensor" as the initializer of the matched ONNX input.
# E.g., "tensor" is stored as the initializer of "attention_weight".
# Step 1 is required because sometimes, tensor names are stored with prefix the dictionary
# loaded by torch.load.
if name in onnx_input_names:
# Same input name shouldn't be matched again
onnx_input_names.remove(name)
else:
for onnx_input_name in onnx_input_names:
if onnx_input_name.endswith(name) or name.endswith(onnx_input_name):
# Find a match. Change name to the matched ONNX input name, so that we
# create initializer with the right ONNX name.
name = onnx_input_name
onnx_input_names.remove(onnx_input_name)
break
relative_tensor_file_path = os.path.join(initializer_location, name)
# Create one file per tensor.
# tensor_proto.raw_data is stored to external file at
# os.path.join(basepath, relative_tensor_file_path).
model_input_types = {k.name: k.type for k in onnx_model.graph.input}
# Mark for deletion - a replacement will be appended next
if name in existing_initializers:
initializers_to_be_deleted[existing_initializers[name]] = name
tensor_proto = _create_tensor_proto_with_external_data(
tensor,
name,
relative_tensor_file_path,
basepath,
model_input_types.pop(name, None),
)
# Add the tensor_proto to the ONNX model as an initializer with external data.
onnx_model.graph.initializer.append(tensor_proto)
# Remove old duplicated initializers, if any. delete in desc order to not invalidate deletion indices
initializers_to_be_deleted = dict(
sorted(initializers_to_be_deleted.items(), reverse=True)
)
for idx in initializers_to_be_deleted.keys():
del onnx_model.graph.initializer[idx]
# model_location should be a pure file name such as "file_name.onnx", not "folder/file_name.onnx".
onnx.save(onnx_model, os.path.join(basepath, model_location)) # type: ignore[attr-defined]