[ONNX] Move symbolic export to separate file (#95650)
Move things around in preparation for refactoring the code structure.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95650
Approved by: https://github.com/titaiwangms
This commit is contained in:
parent d06729746c
commit 82dba844bb
@@ -1,10 +1,7 @@
from .context import FxToOnnxContext
from .exporter import (
    export,
    export_after_normalizing_args_and_kwargs,
    export_without_parameters_and_buffers,
    save_model_with_external_data,
)
from .exporter import export, export_after_normalizing_args_and_kwargs
from .serialization import save_model_with_external_data
from .symbolic_exporter import export_without_parameters_and_buffers


__all__ = [
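For context, the package root re-exports the same four public names before and after this change; only the defining submodule moves. A minimal sketch of the unchanged call-site view (hedged; the exact `__all__` contents are truncated in this diff):

    # Hedged usage sketch: these imports are assumed to resolve the same way
    # before and after the move, since the package root re-exports them.
    from torch.onnx._internal.fx import (
        export,
        export_after_normalizing_args_and_kwargs,
        export_without_parameters_and_buffers,
        save_model_with_external_data,
    )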
@@ -1,11 +1,9 @@
from __future__ import annotations

import copy
import functools
import inspect
import itertools
import operator
import os
import re
import warnings
from types import FunctionType
@@ -75,117 +73,6 @@ onnxscript.OnnxFunction.__call__ = _diagnose_onnx_function(
)


class ModuleExpansionTracer(torch.fx._symbolic_trace.Tracer):
    """Tracer that creates an ONNX-export-friendly FX graph.

    This tracer traces models into operators. That is, the traced graph
    mostly contains call_function nodes and has no call_module nodes.
    Such call_module nodes are problematic for the use of make_fx(...)
    in the ONNX exporter.
    """

    @_beartype.beartype
    def is_leaf_module(
        self, module: torch.nn.Module, module_qualified_name: str
    ) -> bool:
        # This returns False so that all sub-modules are considered not leaves
        # and are therefore expanded into operators in
        # torch.fx._symbolic_trace.Tracer.call_module.
        return False

    @_beartype.beartype
    def to_bool(self, obj: "torch.fx.Proxy") -> bool:
        # This is a hack to trace through if-else Python blocks.
        # It may generate incorrect ONNX graphs if the if-else block
        # depends on the traced values, because this always takes the
        # False branch.
        return False
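To illustrate the difference this tracer makes, here is a hedged sketch (the `Net` module is illustrative, and `ModuleExpansionTracer` is assumed to be in scope):

    import torch
    import torch.fx

    class Net(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x):
            return self.linear(x) * 2

    # Default FX tracing keeps torch.nn.Linear as a leaf, so the graph has a
    # call_module node for "linear".
    default_graph = torch.fx.symbolic_trace(Net()).graph
    # ModuleExpansionTracer never treats a submodule as a leaf, so the same
    # model traces into call_function nodes (torch.nn.functional.linear, mul).
    expanded_graph = ModuleExpansionTracer().trace(Net())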
# Functions directly wrapped to produce torch.fx.Proxy so that symbolic
# data can flow through those functions. Python functions (e.g., `torch.arange`)
# not defined by pybind11 in C++ do not go through the Python dispatcher, so
# they are not automatically patched by FX's Python dispatcher.
# The list below means `torch.arange`, `torch.tensor`, and so on will be
# patched.
_TORCH_METHODS_TO_PATCH: Tuple[str, ...] = (
    "arange",
    "tensor",
    "finfo",
    "full",
    "empty",
)


def _wrap_for_symbolic_trace(target: Callable) -> Tuple[Callable, Callable]:
    """Wrap `target` for symbolic tracing.

    This function wraps `target` so that its wrapper produces
    torch.fx.Proxy in symbolic computation. The returned values are
    the wrapper and then the original function. Per `_TORCH_METHODS_TO_PATCH`,
    this function shall receive `torch.arange`, `torch.tensor`, etc. as inputs.
    """

    @functools.wraps(target)
    def wrapper(*args, **kwargs):
        proxy = None

        def check_has_proxy(v):
            if isinstance(v, torch.fx.Proxy):
                nonlocal proxy
                proxy = v

        torch.fx.node.map_aggregate(args, check_has_proxy)
        torch.fx.node.map_aggregate(kwargs, check_has_proxy)

        if proxy is not None:
            return proxy.tracer.create_proxy("call_function", target, args, kwargs)
        else:
            return target(*args, **kwargs)

    return wrapper, target
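A hedged sketch of what the wrapper buys during tracing (the `scale` function and `M` module are illustrative; `_wrap_for_symbolic_trace` is assumed in scope):

    import torch
    import torch.fx

    def scale(t, factor=2.0):
        return t * factor

    wrapped_scale, _ = _wrap_for_symbolic_trace(scale)

    class M(torch.nn.Module):
        def forward(self, x):
            return wrapped_scale(x)

    gm = torch.fx.symbolic_trace(M())
    # The wrapper detected the Proxy argument and recorded a single
    # call_function node targeting `scale`, instead of inlining its body.
    print(gm.graph)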

@_beartype.beartype
def _module_expansion_symbolic_trace(
    root: Union[torch.nn.Module, Callable[..., Any]],
    concrete_args: Optional[Dict[str, Any]] = None,
) -> "torch.fx.GraphModule":
    """Trace a callable into an FX graph.

    When "root" is a torch.nn.Module, calls to its submodules (type: torch.nn.Module)
    will be expanded into operators (e.g., torch.matmul, torch.add, +, and -)
    to simplify the graph structure.
    """
    # For functions that don't support symbolic tracing, create wrappers
    # which produce symbolic results during tracing.
    patched_torch_methods = {
        target_name: _wrap_for_symbolic_trace(getattr(torch, target_name))
        for target_name in _TORCH_METHODS_TO_PATCH
    }

    # Set the symbolic-tracing friendly functions so that `tracer.trace` below
    # can work.
    for name, (wrapper, _) in patched_torch_methods.items():
        setattr(torch, name, wrapper)

    try:
        # Set up a tracer.
        tracer = ModuleExpansionTracer()
        # Trace the model.
        graph = tracer.trace(root, concrete_args)
        name = (
            root.__class__.__name__
            if isinstance(root, torch.nn.Module)
            else root.__name__
        )
        return torch.fx.GraphModule(tracer.root, graph, name)
    finally:
        # Revert the patches for symbolic tracing.
        for name, (_, wrapped) in patched_torch_methods.items():
            # wrapped is the original version of `torch.<name>`.
            setattr(torch, name, wrapped)
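And a hedged usage sketch of the function above (the `Block` module is illustrative):

    import torch

    class Block(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(8, 8)

        def forward(self, x):
            return torch.relu(self.linear(x))

    gm = _module_expansion_symbolic_trace(Block())
    # Submodule calls were inlined, so no call_module nodes remain.
    assert all(node.op != "call_module" for node in gm.graph.nodes)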

def _retrieve_or_adapt_input_to_graph_set(fx_node_arg, fx_name_to_onnxscipt_value):
    """Map FX value to TorchScript value.
@@ -780,309 +667,6 @@ def export_after_normalizing_args_and_kwargs(
)

@_beartype.beartype
def _move_placeholder_to_front(graph_module: "torch.fx.GraphModule") -> None:
    """
    This function moves all placeholder nodes to the front of the graph node list.
    In torch.fx.Graph, placeholder is a special assignment node. If it's not
    executed at the beginning, it could overwrite values computed by upstream
    nodes.
    """

    graph = graph_module.graph
    placeholders = []
    first_not_placeholder = None
    for node in graph.nodes:
        if node.op == "placeholder":
            placeholders.append(node)
        if first_not_placeholder is None and node.op != "placeholder":
            first_not_placeholder = node
    if first_not_placeholder is None:
        return
    for placeholder in placeholders:
        first_not_placeholder.prepend(placeholder)

@_beartype.beartype
def _replace_get_attr_with_placeholder(
    graph_module: "torch.fx.GraphModule",
) -> Tuple[torch.Tensor, ...]:
    """
    Replace get_attr with placeholder.
    The parameters and buffers accessed by the original get_attr are returned;
    they are useful when creating random inputs for the modified graph_module.
    """
    graph = graph_module.graph
    replaced_attrs: List[torch.Tensor] = []
    for node in graph.nodes:
        if node.op == "get_attr":
            replaced_attr: Optional[torch.Tensor] = None
            # get_attr could retrieve either a parameter or a buffer, so
            # we need to try both.
            try:
                replaced_attr = graph_module.get_parameter(node.target)
            except AttributeError:
                # It's possible that the model author uses a buffer instead of
                # a parameter to store trainable weights. In this case,
                # 1. get_parameter will throw something like
                #    AttributeError: `bias` is not an nn.Parameter.
                # 2. get_buffer should work.
                replaced_attr = graph_module.get_buffer(node.target)

            # Reassign op type so that the get_attr node becomes a placeholder node.
            node.op = "placeholder"
            # The target name in placeholder must be a valid Python identifier.
            # Thus, we replace, e.g., "module.submodule.weight" with
            # "module_submodule_weight".
            node.target = node.target.replace(".", "_")
            # The default value is None. This is needed as long as "graph_module"
            # has optional inputs. Assume the original forward signature is
            #   def forward(self, x, y=None)
            # and the replaced get_attr node has target "z". Then, the modified
            # signature should be
            #   def forward(self, x, y=None, z=None)
            # Without the following line, the signature would be
            #   def forward(self, x, y=None, z)
            # which is not valid Python code.
            node.args = (None,)

            replaced_attrs.append(replaced_attr)

    return tuple(replaced_attrs)
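A hedged sketch of the observable effect of the two graph passes above (the `Affine` module is illustrative; the helpers are assumed in scope):

    import torch
    import torch.fx

    class Affine(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones(3))

        def forward(self, x):
            return x * self.weight

    gm = torch.fx.symbolic_trace(Affine())
    # Before: the graph reads self.weight through a get_attr node.
    weights = _replace_get_attr_with_placeholder(gm)
    _move_placeholder_to_front(gm)
    gm.recompile()
    # After: "weight" is an extra graph input, and `weights` holds the tensor
    # a caller can pass for it, e.g. gm(torch.randn(3), *weights).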

@_beartype.beartype
def _trace_into_fx_graph_via_fx_symbolic_trace(
    module: torch.nn.Module,
    *args,
    # kwargs are the keyword arguments to call "module"; that is,
    # module(*args, **kwargs) must run.
    **kwargs,
) -> Tuple["torch.fx.GraphModule", Tuple[Any, ...]]:
    signature = inspect.signature(module.forward)

    # We hope the input kwargs will be mapped to bound.args after binding.
    # If not, we will raise an error.
    bound = signature.bind(*args, **kwargs)
    bound.apply_defaults()
    # After apply_defaults, all non-keyword-only arguments are in bound.args.
    # Because the code below does not support keyword-only arguments,
    # bound.kwargs must be empty.
    assert len(bound.kwargs) == 0, bound.kwargs

    # Create inputs to call symbolic trace (torch.fx.symbolic_trace).
    # Example content of concrete_args:
    #   concrete_args["x"] = torch.fx._symbolic_trace.PH
    #   concrete_args["b"] = 1
    # where "x" and "b" are argument names in "signature".
    concrete_args = {}
    for param_name, param_value in bound.arguments.items():
        if isinstance(param_value, torch.Tensor):
            # param_value can be, e.g., a real tensor or a fake tensor.
            # param_value is treated as a substitutable tensor symbol (aka placeholder).
            concrete_args[param_name] = torch.fx._symbolic_trace.PH
        else:
            concrete_args[param_name] = param_value

    return (
        _module_expansion_symbolic_trace(module, concrete_args=concrete_args),
        bound.args,
    )
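For intuition, a hedged sketch of the binding logic in isolation (the `forward` signature is illustrative):

    import inspect
    import torch
    import torch.fx

    def forward(x: torch.Tensor, flag: bool = True):
        ...

    bound = inspect.signature(forward).bind(torch.randn(2), flag=False)
    bound.apply_defaults()
    # Tensors become substitutable placeholders; everything else is baked in,
    # yielding {"x": torch.fx._symbolic_trace.PH, "flag": False}.
    concrete_args = {
        name: torch.fx._symbolic_trace.PH if isinstance(v, torch.Tensor) else v
        for name, v in bound.arguments.items()
    }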

@_beartype.beartype
def export_without_parameters_and_buffers(
    module: torch.nn.Module,
    *args,
    decomposition_table: Optional[Dict[torch._ops.OpOverload, Callable]] = None,
    use_binary_format: bool = True,
    opset_version: int = _constants.ONNX_DEFAULT_OPSET,
    op_level_debug: bool = False,
    # kwargs are the keyword arguments to call "module"; that is,
    # module(*args, **kwargs) must run.
    **kwargs,
) -> Tuple[
    Union["onnx.ModelProto", bytes],
    "torch.fx.GraphModule",
    Tuple[Any, ...],
    Tuple[Any, ...],
]:

    graph_module, bound_args = _trace_into_fx_graph_via_fx_symbolic_trace(
        module, *args, **kwargs
    )

    # Make sure all placeholder nodes are executed before get_attr nodes.
    # Otherwise, inputs can interleave with initializers in the final
    # ModelProto.graph.input. Basically, we want
    #   ModelProto.graph.input =
    #     [input_0, input_1, ..., input_n, weight_0, weight_1, ..., weight_m]
    # and we don't want
    #   ModelProto.graph.input =
    #     [input_0, weight_0, input_1, weight_1, ..., input_n, weight_0, weight_1, ..., weight_m]
    _move_placeholder_to_front(graph_module)
    # To save memory, move get_attr to input so that the generated model doesn't
    # have weight tensors. "replaced_attrs" is the list of replaced weight tensors.
    replaced_attrs = _replace_get_attr_with_placeholder(graph_module)
    # Move all newly created placeholder nodes to the front of the graph.
    _move_placeholder_to_front(graph_module)
    # Finalize the graph editing.
    graph_module.recompile()
    return (
        _export(
            graph_module,
            (*bound_args, *replaced_attrs),
            opset_version=opset_version,
            decomposition_table=decomposition_table,
            use_binary_format=use_binary_format,
            op_level_debug=op_level_debug,
        ),
        graph_module,
        bound_args,
        replaced_attrs,
    )
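A hedged end-to-end sketch of calling this API (model and shapes illustrative):

    import torch

    model = torch.nn.Linear(4, 2)
    onnx_model, graph_module, bound_args, replaced_attrs = (
        export_without_parameters_and_buffers(
            model,
            torch.randn(1, 4),
            use_binary_format=False,  # return onnx.ModelProto instead of bytes
        )
    )
    # The weight and bias are not embedded in the model; they surface as extra
    # graph inputs and are returned in replaced_attrs for later binding.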

@_beartype.beartype
def _create_tensor_proto_with_external_data(
    tensor: torch.Tensor, name: str, location: str, basepath: str
) -> "onnx.TensorProto":
    """Create a TensorProto with external data from a PyTorch tensor.

    The external data is saved to os.path.join(basepath, location).

    Args:
        tensor: Tensor to be saved.
        name: Name of the tensor (i.e., initializer name in the ONNX graph).
        location: Relative location of the external data file
            (e.g., "/tmp/initializers/weight_0" when the model is "/tmp/model_name.onnx").
        basepath: Base path of the external data file (e.g., "/tmp/external_data" while the model must be in "/tmp").

    Reference for ONNX's external data format:
        How to load?
        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L187
        How to save?
        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L43
        How to set ONNX fields?
        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L88
    """
    tensor_proto = onnx.TensorProto()
    tensor_proto.name = name
    tensor_proto.data_type = torch.onnx._type_utils._SCALAR_TYPE_TO_ONNX[  # type: ignore[assignment]
        torch.onnx._type_utils._DTYPE_TO_SCALAR_TYPE[tensor.dtype]
    ]
    tensor_proto.dims.extend(tensor.shape)
    tensor_proto.data_location = onnx.TensorProto.EXTERNAL

    # Settings for saving one tensor per file.
    # Offset is zero because there is no other tensor in the same file.
    key_value_pairs = {
        "location": location,
        "offset": 0,
        "length": tensor.untyped_storage().nbytes(),
    }
    for k, v in key_value_pairs.items():
        entry = tensor_proto.external_data.add()
        entry.key = k
        entry.value = str(v)

    # Actual path to write the content of the tensor.
    external_data_file_path = os.path.join(basepath, location)
    if os.path.exists(external_data_file_path):
        os.remove(external_data_file_path)

    # Create the external data's folder if it does not exist.
    external_data_dir_path = os.path.dirname(external_data_file_path)
    if not os.path.exists(external_data_dir_path):
        os.makedirs(external_data_dir_path)

    # Create a fresh file.
    with open(external_data_file_path, "xb") as data_file:
        # No need to call "seek" because offset is 0.
        # data_file.seek(0)
        # Write tensor content to the file.
        data_file.write(tensor.numpy().tobytes())

    return tensor_proto
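For reference, a hedged sketch of the resulting initializer entry (values illustrative; the external_data entries are key/value string pairs in the proto):

    import torch

    # For a float32 tensor of shape (3,) saved at "initializers/weight":
    #   tensor_proto.name          == "weight"
    #   tensor_proto.dims          == [3]
    #   tensor_proto.data_location == onnx.TensorProto.EXTERNAL
    #   external_data entries      == location="initializers/weight",
    #                                 offset="0", length="12"  (3 * 4 bytes)
    proto = _create_tensor_proto_with_external_data(
        torch.ones(3), "weight", "initializers/weight", "/tmp/model_dir"
    )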

@_beartype.beartype
def save_model_with_external_data(
    basepath: str,
    model_location: str,
    initializer_location: str,
    torch_load_paths: Tuple[str, ...],
    onnx_model: "onnx.ModelProto",
) -> None:
    """Load PyTorch tensors from files and add them to "onnx_model" as external initializers.

    Output files:
        ONNX model file: os.path.join(basepath, model_location)
        ONNX initializer folder: os.path.join(basepath, initializer_location)

    After running this function, you can do
        ort_sess = onnxruntime.InferenceSession(os.path.join(basepath, model_location))
    to execute the model.

    Arguments:
        basepath: Base path of the external data file (e.g., "/tmp/large-onnx-model").
        model_location: Relative location of the ONNX model file.
            E.g., "model.onnx" so that the model file is saved to
            "/tmp/large-onnx-model/model.onnx".
        initializer_location: Relative location of the ONNX initializer folder.
            E.g., "initializers" so that the initializers are saved to
            "/tmp/large-onnx-model/initializers".
        torch_load_paths: Files containing serialized PyTorch tensors to be saved
            as ONNX initializers. They are loaded by torch.load.
        onnx_model: ONNX model to be saved with external initializers.
            If an input name matches a tensor loaded from "torch_load_paths",
            the tensor will be saved as that input's external initializer.
    """
    onnx_model_with_initializers = onnx.ModelProto()
    onnx_model_with_initializers.CopyFrom(onnx_model)
    onnx_input_names = [input.name for input in onnx_model.graph.input]

    for path in torch_load_paths:
        state_dict = torch.load(path)
        for name, tensor in state_dict.items():
            # Basically, "transformer.attention.self.query.weight" is mapped
            # to "transformer_attention_self_query_weight" to mimic the
            # name-modifying code in the FX-to-ONNX exporter.
            # See the function _replace_get_attr_with_placeholder for details.
            refined_name = name.replace(".", "_")

            # For each refined PyTorch tensor name loaded by torch.load,
            #  1. Search for its best match in the ONNX model. E.g., the match of
            #     "transformer_attention_weight" could be "attention_weight".
            #  2. Set "tensor" as the initializer of the matched ONNX input.
            #     E.g., "tensor" is stored as the initializer of "attention_weight".
            # Step 1 is required because sometimes tensor names are stored with a
            # prefix in the dictionary loaded by torch.load.
            for onnx_input_name in onnx_input_names:
                if onnx_input_name.endswith(refined_name) or refined_name.endswith(
                    onnx_input_name
                ):
                    # Found a match. Change refined_name to the matched ONNX input
                    # name, so that we create the initializer with the right ONNX name.
                    refined_name = onnx_input_name
                    break

            relative_tensor_file_path = os.path.join(initializer_location, refined_name)
            # Create one file per tensor.
            # tensor_proto.raw_data is stored to an external file at
            # os.path.join(basepath, relative_tensor_file_path).
            tensor_proto = _create_tensor_proto_with_external_data(
                tensor, refined_name, relative_tensor_file_path, basepath
            )
            # Add the tensor_proto to the ONNX model as an initializer with external data.
            onnx_model_with_initializers.graph.initializer.append(tensor_proto)

    # model_location should be a pure file name such as "file_name.onnx", not "folder/file_name.onnx".
    onnx.save(onnx_model_with_initializers, os.path.join(basepath, model_location))
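A hedged usage sketch tying the serialization helpers to the exporter above (paths and names illustrative; `model` and `onnx_model` are assumed to come from export_without_parameters_and_buffers):

    import torch

    # Persist the weights separately, then attach them as external initializers.
    torch.save(model.state_dict(), "/tmp/large-onnx-model/weights.pt")
    save_model_with_external_data(
        basepath="/tmp/large-onnx-model",
        model_location="model.onnx",
        initializer_location="initializers",
        torch_load_paths=("/tmp/large-onnx-model/weights.pt",),
        onnx_model=onnx_model,
    )
    # The saved model can then be run with onnxruntime:
    #   onnxruntime.InferenceSession("/tmp/large-onnx-model/model.onnx")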

@_beartype.beartype
def _validate_op_between_ort_torch(
    node: torch.fx.Node,
torch/onnx/_internal/fx/serialization.py (new file, 149 lines)
@@ -0,0 +1,149 @@
from __future__ import annotations

import os
from typing import Tuple

import onnx

import torch
from torch.onnx._internal import _beartype


@_beartype.beartype
def _create_tensor_proto_with_external_data(
    tensor: torch.Tensor, name: str, location: str, basepath: str
) -> "onnx.TensorProto":
    """Create a TensorProto with external data from a PyTorch tensor.

    The external data is saved to os.path.join(basepath, location).

    Args:
        tensor: Tensor to be saved.
        name: Name of the tensor (i.e., initializer name in the ONNX graph).
        location: Relative location of the external data file
            (e.g., "/tmp/initializers/weight_0" when the model is "/tmp/model_name.onnx").
        basepath: Base path of the external data file (e.g., "/tmp/external_data" while the model must be in "/tmp").

    Reference for ONNX's external data format:
        How to load?
        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L187
        How to save?
        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L43
        How to set ONNX fields?
        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L88
    """
    tensor_proto = onnx.TensorProto()
    tensor_proto.name = name
    tensor_proto.data_type = torch.onnx._type_utils._SCALAR_TYPE_TO_ONNX[  # type: ignore[assignment]
        torch.onnx._type_utils._DTYPE_TO_SCALAR_TYPE[tensor.dtype]
    ]
    tensor_proto.dims.extend(tensor.shape)
    tensor_proto.data_location = onnx.TensorProto.EXTERNAL

    # Settings for saving one tensor per file.
    # Offset is zero because there is no other tensor in the same file.
    key_value_pairs = {
        "location": location,
        "offset": 0,
        "length": tensor.untyped_storage().nbytes(),
    }
    for k, v in key_value_pairs.items():
        entry = tensor_proto.external_data.add()
        entry.key = k
        entry.value = str(v)

    # Actual path to write the content of the tensor.
    external_data_file_path = os.path.join(basepath, location)
    if os.path.exists(external_data_file_path):
        os.remove(external_data_file_path)

    # Create the external data's folder if it does not exist.
    external_data_dir_path = os.path.dirname(external_data_file_path)
    if not os.path.exists(external_data_dir_path):
        os.makedirs(external_data_dir_path)

    # Create a fresh file.
    with open(external_data_file_path, "xb") as data_file:
        # No need to call "seek" because offset is 0.
        # data_file.seek(0)
        # Write tensor content to the file.
        data_file.write(tensor.numpy().tobytes())

    return tensor_proto


@_beartype.beartype
def save_model_with_external_data(
    basepath: str,
    model_location: str,
    initializer_location: str,
    torch_load_paths: Tuple[str, ...],
    onnx_model: onnx.ModelProto,
) -> None:
    """Load PyTorch tensors from files and add them to "onnx_model" as external initializers.

    Output files:
        ONNX model file: os.path.join(basepath, model_location)
        ONNX initializer folder: os.path.join(basepath, initializer_location)

    After running this function, you can do
        ort_sess = onnxruntime.InferenceSession(os.path.join(basepath, model_location))
    to execute the model.

    Arguments:
        basepath: Base path of the external data file (e.g., "/tmp/large-onnx-model").
        model_location: Relative location of the ONNX model file.
            E.g., "model.onnx" so that the model file is saved to
            "/tmp/large-onnx-model/model.onnx".
        initializer_location: Relative location of the ONNX initializer folder.
            E.g., "initializers" so that the initializers are saved to
            "/tmp/large-onnx-model/initializers".
        torch_load_paths: Files containing serialized PyTorch tensors to be saved
            as ONNX initializers. They are loaded by torch.load.
        onnx_model: ONNX model to be saved with external initializers.
            If an input name matches a tensor loaded from "torch_load_paths",
            the tensor will be saved as that input's external initializer.
    """
    onnx_model_with_initializers = onnx.ModelProto()
    onnx_model_with_initializers.CopyFrom(onnx_model)
    onnx_input_names = [input.name for input in onnx_model.graph.input]

    for path in torch_load_paths:
        state_dict = torch.load(path)
        for name, tensor in state_dict.items():
            # Basically, "transformer.attention.self.query.weight" is mapped
            # to "transformer_attention_self_query_weight" to mimic the
            # name-modifying code in the FX-to-ONNX exporter.
            # See the function _replace_get_attr_with_placeholder for details.
            refined_name = name.replace(".", "_")

            # For each refined PyTorch tensor name loaded by torch.load,
            #  1. Search for its best match in the ONNX model. E.g., the match of
            #     "transformer_attention_weight" could be "attention_weight".
            #  2. Set "tensor" as the initializer of the matched ONNX input.
            #     E.g., "tensor" is stored as the initializer of "attention_weight".
            # Step 1 is required because sometimes tensor names are stored with a
            # prefix in the dictionary loaded by torch.load.
            for onnx_input_name in onnx_input_names:
                if onnx_input_name.endswith(refined_name) or refined_name.endswith(
                    onnx_input_name
                ):
                    # Found a match. Change refined_name to the matched ONNX input
                    # name, so that we create the initializer with the right ONNX name.
                    refined_name = onnx_input_name
                    break

            relative_tensor_file_path = os.path.join(initializer_location, refined_name)
            # Create one file per tensor.
            # tensor_proto.raw_data is stored to an external file at
            # os.path.join(basepath, relative_tensor_file_path).
            tensor_proto = _create_tensor_proto_with_external_data(
                tensor, refined_name, relative_tensor_file_path, basepath
            )
            # Add the tensor_proto to the ONNX model as an initializer with external data.
            onnx_model_with_initializers.graph.initializer.append(tensor_proto)

    # model_location should be a pure file name such as "file_name.onnx", not "folder/file_name.onnx".
    onnx.save(onnx_model_with_initializers, os.path.join(basepath, model_location))
torch/onnx/_internal/fx/symbolic_exporter.py (new file, 289 lines)
@@ -0,0 +1,289 @@
from __future__ import annotations

import functools
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import onnx

import torch
import torch.fx
from torch.onnx import _constants
from torch.onnx._internal import _beartype
from torch.onnx._internal.fx import exporter

# Functions directly wrapped to produce torch.fx.Proxy so that symbolic
# data can flow through those functions. Python functions (e.g., `torch.arange`)
# not defined by pybind11 in C++ do not go through the Python dispatcher, so
# they are not automatically patched by FX's Python dispatcher.
# The list below means `torch.arange`, `torch.tensor`, and so on will be
# patched.
_TORCH_METHODS_TO_PATCH: Tuple[str, ...] = (
    "arange",
    "tensor",
    "finfo",
    "full",
    "empty",
)


class ModuleExpansionTracer(torch.fx._symbolic_trace.Tracer):
    """Tracer that creates an ONNX-export-friendly FX graph.

    This tracer traces models into operators. That is, the traced graph
    mostly contains call_function nodes and has no call_module nodes.
    Such call_module nodes are problematic for the use of make_fx(...)
    in the ONNX exporter.
    """

    @_beartype.beartype
    def is_leaf_module(
        self, module: torch.nn.Module, module_qualified_name: str
    ) -> bool:
        # This returns False so that all sub-modules are considered not leaves
        # and are therefore expanded into operators in
        # torch.fx._symbolic_trace.Tracer.call_module.
        return False

    @_beartype.beartype
    def to_bool(self, obj: "torch.fx.Proxy") -> bool:
        # FIXME: This is a hack to trace through if-else Python blocks.
        # It may generate incorrect ONNX graphs if the if-else block
        # depends on the traced values, because this always takes the
        # False branch.
        return False


@_beartype.beartype
def _move_placeholder_to_front(graph_module: torch.fx.GraphModule) -> None:
    """
    This function moves all placeholder nodes to the front of the graph node list.
    In torch.fx.Graph, placeholder is a special assignment node. If it's not
    executed at the beginning, it could overwrite values computed by upstream
    nodes.
    """

    graph = graph_module.graph
    placeholders = []
    first_not_placeholder = None
    for node in graph.nodes:
        if node.op == "placeholder":
            placeholders.append(node)
        if first_not_placeholder is None and node.op != "placeholder":
            first_not_placeholder = node
    if first_not_placeholder is None:
        return
    for placeholder in placeholders:
        first_not_placeholder.prepend(placeholder)


@_beartype.beartype
def _replace_get_attr_with_placeholder(
    graph_module: torch.fx.GraphModule,
) -> Tuple[torch.Tensor, ...]:
    """
    Replace get_attr with placeholder.
    The parameters and buffers accessed by the original get_attr are returned;
    they are useful when creating random inputs for the modified graph_module.
    """
    graph = graph_module.graph
    replaced_attrs: List[torch.Tensor] = []
    for node in graph.nodes:
        if node.op == "get_attr":
            replaced_attr: Optional[torch.Tensor] = None
            # get_attr could retrieve either a parameter or a buffer, so
            # we need to try both.
            try:
                replaced_attr = graph_module.get_parameter(node.target)
            except AttributeError:
                # It's possible that the model author uses a buffer instead of
                # a parameter to store trainable weights. In this case,
                # 1. get_parameter will throw something like
                #    AttributeError: `bias` is not an nn.Parameter.
                # 2. get_buffer should work.
                replaced_attr = graph_module.get_buffer(node.target)

            # Reassign op type so that the get_attr node becomes a placeholder node.
            node.op = "placeholder"
            # The target name in placeholder must be a valid Python identifier.
            # Thus, we replace, e.g., "module.submodule.weight" with
            # "module_submodule_weight".
            node.target = node.target.replace(".", "_")
            # The default value is None. This is needed as long as "graph_module"
            # has optional inputs. Assume the original forward signature is
            #   def forward(self, x, y=None)
            # and the replaced get_attr node has target "z". Then, the modified
            # signature should be
            #   def forward(self, x, y=None, z=None)
            # Without the following line, the signature would be
            #   def forward(self, x, y=None, z)
            # which is not valid Python code.
            node.args = (None,)

            replaced_attrs.append(replaced_attr)

    return tuple(replaced_attrs)


@_beartype.beartype
def _trace_into_fx_graph_via_fx_symbolic_trace(
    module: torch.nn.Module,
    *args,
    # kwargs are the keyword arguments to call "module"; that is,
    # module(*args, **kwargs) must run.
    **kwargs,
) -> Tuple["torch.fx.GraphModule", Tuple[Any, ...]]:
    signature = inspect.signature(module.forward)

    # We hope the input kwargs will be mapped to bound.args after binding.
    # If not, we will raise an error.
    bound = signature.bind(*args, **kwargs)
    bound.apply_defaults()
    # After apply_defaults, all non-keyword-only arguments are in bound.args.
    # Because the code below does not support keyword-only arguments,
    # bound.kwargs must be empty.
    assert len(bound.kwargs) == 0, bound.kwargs

    # Create inputs to call symbolic trace (torch.fx.symbolic_trace).
    # Example content of concrete_args:
    #   concrete_args["x"] = torch.fx._symbolic_trace.PH
    #   concrete_args["b"] = 1
    # where "x" and "b" are argument names in "signature".
    concrete_args = {}
    for param_name, param_value in bound.arguments.items():
        if isinstance(param_value, torch.Tensor):
            # param_value can be, e.g., a real tensor or a fake tensor.
            # param_value is treated as a substitutable tensor symbol (aka placeholder).
            concrete_args[param_name] = torch.fx._symbolic_trace.PH
        else:
            concrete_args[param_name] = param_value

    return (
        _module_expansion_symbolic_trace(module, concrete_args=concrete_args),
        bound.args,
    )


def _wrap_for_symbolic_trace(target: Callable) -> Tuple[Callable, Callable]:
    """Wrap `target` for symbolic tracing.

    This function wraps `target` so that its wrapper produces
    torch.fx.Proxy in symbolic computation. The returned values are
    the wrapper and then the original function. Per `_TORCH_METHODS_TO_PATCH`,
    this function shall receive `torch.arange`, `torch.tensor`, etc. as inputs.
    """

    @functools.wraps(target)
    def wrapper(*args, **kwargs):
        proxy = None

        def check_has_proxy(v):
            if isinstance(v, torch.fx.Proxy):
                nonlocal proxy
                proxy = v

        torch.fx.node.map_aggregate(args, check_has_proxy)
        torch.fx.node.map_aggregate(kwargs, check_has_proxy)

        if proxy is not None:
            return proxy.tracer.create_proxy("call_function", target, args, kwargs)
        else:
            return target(*args, **kwargs)

    return wrapper, target


@_beartype.beartype
def _module_expansion_symbolic_trace(
    root: Union[torch.nn.Module, Callable[..., Any]],
    concrete_args: Optional[Dict[str, Any]] = None,
) -> torch.fx.GraphModule:
    """Trace a callable into an FX graph.

    When "root" is a torch.nn.Module, calls to its submodules (type: torch.nn.Module)
    will be expanded into operators (e.g., torch.matmul, torch.add, +, and -)
    to simplify the graph structure.
    """
    # For functions that don't support symbolic tracing, create wrappers
    # which produce symbolic results during tracing.
    patched_torch_methods = {
        target_name: _wrap_for_symbolic_trace(getattr(torch, target_name))
        for target_name in _TORCH_METHODS_TO_PATCH
    }

    # Set the symbolic-tracing friendly functions so that `tracer.trace` below
    # can work.
    for name, (wrapper, _) in patched_torch_methods.items():
        setattr(torch, name, wrapper)

    try:
        # Set up a tracer.
        tracer = ModuleExpansionTracer()
        # Trace the model.
        graph = tracer.trace(root, concrete_args)
        name = (
            root.__class__.__name__
            if isinstance(root, torch.nn.Module)
            else root.__name__
        )
        return torch.fx.GraphModule(tracer.root, graph, name)
    finally:
        # Revert the patches for symbolic tracing.
        for name, (_, wrapped) in patched_torch_methods.items():
            # wrapped is the original version of `torch.<name>`.
            setattr(torch, name, wrapped)


@_beartype.beartype
def export_without_parameters_and_buffers(
    module: torch.nn.Module,
    *args,
    decomposition_table: Optional[Dict[torch._ops.OpOverload, Callable]] = None,
    use_binary_format: bool = True,
    opset_version: int = _constants.ONNX_DEFAULT_OPSET,
    op_level_debug: bool = False,
    # kwargs are the keyword arguments to call "module"; that is,
    # module(*args, **kwargs) must run.
    **kwargs,
) -> Tuple[
    Union[onnx.ModelProto, bytes],
    torch.fx.GraphModule,
    Tuple[Any, ...],
    Tuple[Any, ...],
]:

    graph_module, bound_args = _trace_into_fx_graph_via_fx_symbolic_trace(
        module, *args, **kwargs
    )

    # Make sure all placeholder nodes are executed before get_attr nodes.
    # Otherwise, inputs can interleave with initializers in the final
    # ModelProto.graph.input. Basically, we want
    #   ModelProto.graph.input =
    #     [input_0, input_1, ..., input_n, weight_0, weight_1, ..., weight_m]
    # and we don't want
    #   ModelProto.graph.input =
    #     [input_0, weight_0, input_1, weight_1, ..., input_n, weight_0, weight_1, ..., weight_m]
    _move_placeholder_to_front(graph_module)
    # To save memory, move get_attr to input so that the generated model doesn't
    # have weight tensors. "replaced_attrs" is the list of replaced weight tensors.
    replaced_attrs = _replace_get_attr_with_placeholder(graph_module)
    # Move all newly created placeholder nodes to the front of the graph.
    _move_placeholder_to_front(graph_module)
    # Finalize the graph editing.
    graph_module.recompile()
    return (
        exporter._export(
            graph_module,
            (*bound_args, *replaced_attrs),
            opset_version=opset_version,
            decomposition_table=decomposition_table,
            use_binary_format=use_binary_format,
            op_level_debug=op_level_debug,
        ),
        graph_module,
        bound_args,
        replaced_attrs,
    )