Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
[ONNX] Move symbolic export to separate file (#95650)

Move things around in preparation for refactoring the code structure.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95650
Approved by: https://github.com/titaiwangms

Parent: d06729746c
Commit: 82dba844bb
@@ -1,10 +1,7 @@
 from .context import FxToOnnxContext
-from .exporter import (
-    export,
-    export_after_normalizing_args_and_kwargs,
-    export_without_parameters_and_buffers,
-    save_model_with_external_data,
-)
+from .exporter import export, export_after_normalizing_args_and_kwargs
+from .serialization import save_model_with_external_data
+from .symbolic_exporter import export_without_parameters_and_buffers
 
 
 __all__ = [
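For orientation, here is a minimal sketch (not part of the commit) of what the refactor means for call sites: all four names are still re-exported from torch.onnx._internal.fx, so only their defining modules change.

# Illustrative only; these are private PyTorch internals at this commit's revision.
from torch.onnx._internal.fx import (
    export,
    export_after_normalizing_args_and_kwargs,
    export_without_parameters_and_buffers,  # now lives in symbolic_exporter.py
    save_model_with_external_data,          # now lives in serialization.py
)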
@@ -1,11 +1,9 @@
 from __future__ import annotations
 
 import copy
-import functools
 import inspect
 import itertools
 import operator
-import os
 import re
 import warnings
 from types import FunctionType
@@ -75,117 +73,6 @@ onnxscript.OnnxFunction.__call__ = _diagnose_onnx_function(
 )
-
-
-class ModuleExpansionTracer(torch.fx._symbolic_trace.Tracer):
-    """Tracer to create an ONNX-export-friendly FX graph.
-
-    This tracer traces models into operators. That is,
-    the traced graph mostly contains call_function nodes and
-    has no call_module nodes. The call_module nodes
-    are problematic for the use of make_fx(...) in the ONNX
-    exporter.
-    """
-
-    @_beartype.beartype
-    def is_leaf_module(
-        self, module: torch.nn.Module, module_qualified_name: str
-    ) -> bool:
-        # This returns False so that all sub-modules are considered not leaves
-        # and therefore expanded into operators in
-        # torch.fx._symbolic_trace.Tracer.call_module.
-        return False
-
-    @_beartype.beartype
-    def to_bool(self, obj: "torch.fx.Proxy") -> bool:
-        # This is a hack to trace through if-else Python blocks.
-        # It may generate incorrect ONNX graphs if the if-else block
-        return False
-
-
-# Functions directly wrapped to produce torch.fx.Proxy so that symbolic
-# data can flow through those functions. Python functions (e.g., `torch.arange`)
-# not defined by pybind11 in C++ do not go through the Python dispatcher, so
-# they are not automatically patched by FX's Python dispatcher.
-# The list below means `torch.arange`, `torch.tensor`, and so on will be
-# patched.
-_TORCH_METHODS_TO_PATCH: Tuple[str, ...] = (
-    "arange",
-    "tensor",
-    "finfo",
-    "full",
-    "empty",
-)
-
-
-def _wrap_for_symbolic_trace(target: Callable) -> Tuple[Callable, Callable]:
-    """This function wraps ``target`` for symbolic tracing.
-
-    This function wraps ``target`` so that its wrapper produces
-    torch.fx.Proxy in symbolic computation. The returned values are
-    the wrapper and then the original function. Per `_TORCH_METHODS_TO_PATCH`,
-    this function shall receive `torch.arange`, `torch.tensor`, etc. as inputs.
-    """
-
-    @functools.wraps(target)
-    def wrapper(*args, **kwargs):
-        proxy = None
-
-        def check_has_proxy(v):
-            if isinstance(v, torch.fx.Proxy):
-                nonlocal proxy
-                proxy = v
-
-        torch.fx.node.map_aggregate(args, check_has_proxy)
-        torch.fx.node.map_aggregate(kwargs, check_has_proxy)
-
-        if proxy is not None:
-            return proxy.tracer.create_proxy("call_function", target, args, kwargs)
-        else:
-            return target(*args, **kwargs)
-
-    return wrapper, target
-
-
-@_beartype.beartype
-def _module_expansion_symbolic_trace(
-    root: Union[torch.nn.Module, Callable[..., Any]],
-    concrete_args: Optional[Dict[str, Any]] = None,
-) -> "torch.fx.GraphModule":
-    """Trace a callable into an FX graph.
-
-    When "root" is a torch.nn.Module, calls to its submodules (type: torch.nn.Module)
-    will be expanded into operators (e.g., torch.matmul, torch.add, +, and -) to
-    simplify the graph structure.
-    """
-    # For functions that don't support symbolic tracing, create wrappers
-    # which produce symbolic results during tracing.
-    patched_torch_methods = {
-        target_name: _wrap_for_symbolic_trace(getattr(torch, target_name))
-        for target_name in _TORCH_METHODS_TO_PATCH
-    }
-
-    # Set the symbolic-tracing friendly functions so that `tracer.trace` below
-    # can work.
-    for name, (wrapper, _) in patched_torch_methods.items():
-        setattr(torch, name, wrapper)
-
-    try:
-        # Set up a tracer.
-        tracer = ModuleExpansionTracer()
-        # Trace the model.
-        graph = tracer.trace(root, concrete_args)
-        name = (
-            root.__class__.__name__
-            if isinstance(root, torch.nn.Module)
-            else root.__name__
-        )
-        return torch.fx.GraphModule(tracer.root, graph, name)
-    finally:
-        # Revert the patches for symbolic tracing.
-        for name, (_, wrapped) in patched_torch_methods.items():
-            # wrapped is the original version of `torch.<name>`.
-            setattr(torch, name, wrapped)
-
-
 def _retrieve_or_adapt_input_to_graph_set(fx_node_arg, fx_name_to_onnxscipt_value):
     """Map FX value to TorchScript value.
 
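The deleted ModuleExpansionTracer above (re-added in symbolic_exporter.py below) works by answering False in is_leaf_module, so every submodule is inlined during tracing. A self-contained sketch of that effect, using only public torch.fx APIs (ExpandAllTracer and Net are illustrative stand-ins, not names from the commit):

import torch
import torch.fx


class ExpandAllTracer(torch.fx.Tracer):  # same base class the commit subclasses via torch.fx._symbolic_trace
    def is_leaf_module(self, module: torch.nn.Module, module_qualified_name: str) -> bool:
        # Treat no module as a leaf, so nn.Linear etc. are expanded into operators.
        return False


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x) + 1


default_ops = [n.op for n in torch.fx.symbolic_trace(Net()).graph.nodes]
expanded_ops = [n.op for n in ExpandAllTracer().trace(Net()).nodes]
print(default_ops)   # includes a 'call_module' node for self.linear
print(expanded_ops)  # no 'call_module'; linear becomes get_attr + call_function nodes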
@@ -780,309 +667,6 @@ def export_after_normalizing_args_and_kwargs(
 )
-
-
-@_beartype.beartype
-def _move_placeholder_to_front(graph_module: "torch.fx.GraphModule") -> None:
-    """
-    This function moves all placeholder nodes to the front of the graph node list.
-    In torch.fx.Graph, placeholder is a special assignment node. If it's not
-    executed at the beginning, it could overwrite values computed by upstream
-    nodes.
-    """
-
-    graph = graph_module.graph
-    placeholders = []
-    first_not_placeholder = None
-    for node in graph.nodes:
-        if node.op == "placeholder":
-            placeholders.append(node)
-        if first_not_placeholder is None and node.op != "placeholder":
-            first_not_placeholder = node
-    if first_not_placeholder is None:
-        return
-    for placeholder in placeholders:
-        first_not_placeholder.prepend(placeholder)
-
-
-@_beartype.beartype
-def _replace_get_attr_with_placeholder(
-    graph_module: "torch.fx.GraphModule",
-) -> Tuple[torch.Tensor, ...]:
-    """
-    Replace get_attr with placeholder.
-    The parameters and buffers accessed by the original get_attr are returned;
-    they are useful when creating random inputs for the modified graph_module.
-    """
-    graph = graph_module.graph
-    replaced_attrs: List[torch.Tensor] = []
-    for node in graph.nodes:
-        if node.op == "get_attr":
-            replaced_attr: Optional[torch.Tensor] = None
-            # get_attr could retrieve either a parameter or a buffer, so
-            # we need to try both.
-            try:
-                replaced_attr = graph_module.get_parameter(node.target)
-            except AttributeError:
-                # It's possible that the model author uses a buffer instead of
-                # a parameter to store trainable weights. In this case,
-                #  1. get_parameter will throw something like
-                #     AttributeError: `bias` is not an nn.Parameter.
-                #  2. get_buffer should work.
-                replaced_attr = graph_module.get_buffer(node.target)
-
-            # Reassign op type so that the get_attr node becomes a placeholder node.
-            node.op = "placeholder"
-            # The target name in placeholder must be a valid Python identifier.
-            # Thus, we replace, e.g., "module.submodule.weight" with
-            # "module_submodule_weight".
-            node.target = node.target.replace(".", "_")
-            # Default value is None. This is needed as long as the "graph_module"
-            # has optional inputs. Assume the original forward signature is
-            #   def forward(self, x, y=None)
-            # and the replaced get_attr node has target "z". Then, the modified
-            # signature should be
-            #   def forward(self, x, y=None, z=None)
-            # Without the following line, the signature would be
-            #   def forward(self, x, y=None, z)
-            # , which is not valid Python code.
-            node.args = (None,)
-
-            replaced_attrs.append(replaced_attr)
-
-    return tuple(replaced_attrs)
-
-
-@_beartype.beartype
-def _trace_into_fx_graph_via_fx_symbolic_trace(
-    module: torch.nn.Module,
-    *args,
-    # kwargs are the keyword arguments to call "module"; that is,
-    # module(*args, **kwargs) must run.
-    **kwargs,
-) -> Tuple["torch.fx.GraphModule", Tuple[Any, ...]]:
-    signature = inspect.signature(module.forward)
-
-    # We hope the input kwargs will be mapped to bound.args after binding.
-    # If not, we will raise an error.
-    bound = signature.bind(*args, **kwargs)
-    bound.apply_defaults()
-    # After apply_defaults, all non keyword-only arguments are in bound.args.
-    # Because the code below does not support keyword arguments, bound.kwargs
-    # must be empty.
-    assert len(bound.kwargs) == 0, bound.kwargs
-
-    # Create inputs to call symbolic trace (torch.fx.symbolic_trace).
-    # Example content of concrete_args:
-    #   concrete_args["x"] = torch.fx._symbolic_trace.PH
-    #   concrete_args["b"] = 1
-    # where "x" and "b" are argument names in "signature".
-    concrete_args = {}
-    for param_name, param_value in bound.arguments.items():
-        if isinstance(param_value, torch.Tensor):
-            # param_value can be, e.g., a real tensor or a fake tensor.
-            # param_value is treated as a substitutable tensor symbol (aka placeholder).
-            concrete_args[param_name] = torch.fx._symbolic_trace.PH
-        else:
-            concrete_args[param_name] = param_value
-
-    return (
-        _module_expansion_symbolic_trace(module, concrete_args=concrete_args),
-        bound.args,
-    )
-
-
-@_beartype.beartype
-def export_without_parameters_and_buffers(
-    module: torch.nn.Module,
-    *args,
-    decomposition_table: Optional[Dict[torch._ops.OpOverload, Callable]] = None,
-    use_binary_format: bool = True,
-    opset_version: int = _constants.ONNX_DEFAULT_OPSET,
-    op_level_debug: bool = False,
-    # kwargs are the keyword arguments to call "module"; that is,
-    # module(*args, **kwargs) must run.
-    **kwargs,
-) -> Tuple[
-    Union["onnx.ModelProto", bytes],
-    "torch.fx.GraphModule",
-    Tuple[Any, ...],
-    Tuple[Any, ...],
-]:
-
-    graph_module, bound_args = _trace_into_fx_graph_via_fx_symbolic_trace(
-        module, *args, **kwargs
-    )
-
-    # Make sure all placeholder nodes are executed before get_attr nodes.
-    # Otherwise, inputs can interleave with initializers in the final
-    # ModelProto.graph.input. Basically, we want
-    #   ModelProto.graph.input =
-    #     [input_0, input_1, ..., input_n, weight_0, weight_1, ..., weight_m]
-    # and we don't want
-    #   ModelProto.graph.input =
-    #     [input_0, weight_0, input_1, weight_1, ..., input_n, weight_0, weight_1, ..., weight_m]
-    _move_placeholder_to_front(graph_module)
-    # To save memory, move get_attr to input so that the generated model doesn't
-    # hold weight tensors. "replaced_attrs" is the list of replaced weight tensors.
-    replaced_attrs = _replace_get_attr_with_placeholder(graph_module)
-    # Move all newly created placeholder nodes to the front of the graph.
-    _move_placeholder_to_front(graph_module)
-    # Finalize the graph editing.
-    graph_module.recompile()
-    return (
-        _export(
-            graph_module,
-            (*bound_args, *replaced_attrs),
-            opset_version=opset_version,
-            decomposition_table=decomposition_table,
-            use_binary_format=use_binary_format,
-            op_level_debug=op_level_debug,
-        ),
-        graph_module,
-        bound_args,
-        replaced_attrs,
-    )
-
-
-@_beartype.beartype
-def _create_tensor_proto_with_external_data(
-    tensor: torch.Tensor, name: str, location: str, basepath: str
-) -> "onnx.TensorProto":
-    """Create a TensorProto with external data from a PyTorch tensor.
-
-    The external data is saved to os.path.join(basepath, location).
-
-    Args:
-        tensor: Tensor to be saved.
-        name: Name of the tensor (i.e., initializer name in the ONNX graph).
-        location: Relative location of the external data file
-            (e.g., "/tmp/initializers/weight_0" when the model is "/tmp/model_name.onnx").
-        basepath: Base path of the external data file (e.g., "/tmp/external_data" while the model must be in "/tmp").
-
-    Reference for ONNX's external data format:
-        How to load?
-        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L187
-        How to save?
-        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L43
-        How to set ONNX fields?
-        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L88
-    """
-    tensor_proto = onnx.TensorProto()
-    tensor_proto.name = name
-    tensor_proto.data_type = torch.onnx._type_utils._SCALAR_TYPE_TO_ONNX[  # type: ignore[assignment]
-        torch.onnx._type_utils._DTYPE_TO_SCALAR_TYPE[tensor.dtype]
-    ]
-    tensor_proto.dims.extend(tensor.shape)
-    tensor_proto.data_location = onnx.TensorProto.EXTERNAL
-
-    # Settings for saving one tensor per file.
-    # Offset is zero because there is no other tensor in the same file.
-    key_value_pairs = {
-        "location": location,
-        "offset": 0,
-        "length": tensor.untyped_storage().nbytes(),
-    }
-    for k, v in key_value_pairs.items():
-        entry = tensor_proto.external_data.add()
-        entry.key = k
-        entry.value = str(v)
-
-    # Actual path to write the tensor's content.
-    external_data_file_path = os.path.join(basepath, location)
-    if os.path.exists(external_data_file_path):
-        os.remove(external_data_file_path)
-
-    # Create the external data's folder if it does not exist yet.
-    external_data_dir_path = os.path.dirname(external_data_file_path)
-    if not os.path.exists(external_data_dir_path):
-        os.makedirs(external_data_dir_path)
-
-    # Create a fresh file.
-    with open(external_data_file_path, "xb") as data_file:
-        # No need to call "seek" because offset is 0.
-        # data_file.seek(0)
-        # Write tensor content to the file.
-        data_file.write(tensor.numpy().tobytes())
-
-    return tensor_proto
-
-
-@_beartype.beartype
-def save_model_with_external_data(
-    basepath: str,
-    model_location: str,
-    initializer_location: str,
-    torch_load_paths: Tuple[str, ...],
-    onnx_model: "onnx.ModelProto",
-) -> None:
-    """Load PyTorch tensors from files and add them to "onnx_model" as external initializers.
-
-    Output files:
-        ONNX model file: os.path.join(basepath, model_location)
-        ONNX initializer folder: os.path.join(basepath, initializer_location)
-
-    After running this function, you can do
-        ort_sess = onnxruntime.InferenceSession(os.path.join(basepath, model_location))
-    to execute the model.
-
-    Arguments:
-        basepath: Base path of the external data file (e.g., "/tmp/large-onnx-model").
-        model_location: Relative location of the ONNX model file.
-            E.g., "model.onnx" so that the model file is saved to
-            "/tmp/large-onnx-model/model.onnx".
-        initializer_location: Relative location of the ONNX initializer folder.
-            E.g., "initializers" so that the initializers are saved to
-            "/tmp/large-onnx-model/initializers".
-        torch_load_paths: Files containing serialized PyTorch tensors to be saved
-            as ONNX initializers. They are loaded by torch.load.
-        onnx_model: ONNX model to be saved with external initializers.
-            If an input name matches a tensor loaded from "torch_load_paths",
-            the tensor will be saved as that input's external initializer.
-    """
-    onnx_model_with_initializers = onnx.ModelProto()
-    onnx_model_with_initializers.CopyFrom(onnx_model)
-    onnx_input_names = [input.name for input in onnx_model.graph.input]
-
-    for path in torch_load_paths:
-        state_ditc = torch.load(path)
-        for name, tensor in state_ditc.items():
-            # Basically, "transformer.attention.self.query.weight" is mapped
-            # to "transformer_attention_self_query_weight" to mimic the
-            # name-modifying code in the FX-to-ONNX exporter.
-            # See function _replace_get_attr_with_placeholder for details.
-            refined_name = name.replace(".", "_")
-
-            # For each refined PyTorch tensor name loaded by torch.load,
-            #  1. Search for its best match in the ONNX model. E.g., the match of
-            #     "transformer_attention_weight" could be "attention_weight".
-            #  2. Set "tensor" as the initializer of the matched ONNX input.
-            #     E.g., "tensor" is stored as the initializer of "attention_weight".
-            # Step 1 is required because sometimes tensor names are stored with a
-            # prefix in the dictionary loaded by torch.load.
-            for onnx_input_name in onnx_input_names:
-                if onnx_input_name.endswith(refined_name) or refined_name.endswith(
-                    onnx_input_name
-                ):
-                    # Found a match. Change refined_name to the matched ONNX input
-                    # name, so that we create the initializer with the right ONNX name.
-                    refined_name = onnx_input_name
-                    break
-
-            relative_tensor_file_path = os.path.join(initializer_location, refined_name)
-            # Create one file per tensor.
-            # tensor_proto.raw_data is stored to an external file at
-            # os.path.join(basepath, relative_tensor_file_path).
-            tensor_proto = _create_tensor_proto_with_external_data(
-                tensor, refined_name, relative_tensor_file_path, basepath
-            )
-            # Add the tensor_proto to the ONNX model as an initializer with external data.
-            onnx_model_with_initializers.graph.initializer.append(tensor_proto)
-
-    # model_location should be a pure file name such as "file_name.onnx", not "folder/file_name.onnx".
-    onnx.save(onnx_model_with_initializers, os.path.join(basepath, model_location))
-
-
 @_beartype.beartype
 def _validate_op_between_ort_torch(
     node: torch.fx.Node,
torch/onnx/_internal/fx/serialization.py (new file, 149 lines)
@@ -0,0 +1,149 @@
from __future__ import annotations

import os
from typing import Tuple

import onnx

import torch
from torch.onnx._internal import _beartype


@_beartype.beartype
def _create_tensor_proto_with_external_data(
    tensor: torch.Tensor, name: str, location: str, basepath: str
) -> "onnx.TensorProto":
    """Create a TensorProto with external data from a PyTorch tensor.

    The external data is saved to os.path.join(basepath, location).

    Args:
        tensor: Tensor to be saved.
        name: Name of the tensor (i.e., initializer name in the ONNX graph).
        location: Relative location of the external data file
            (e.g., "/tmp/initializers/weight_0" when the model is "/tmp/model_name.onnx").
        basepath: Base path of the external data file (e.g., "/tmp/external_data" while the model must be in "/tmp").

    Reference for ONNX's external data format:
        How to load?
        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L187
        How to save?
        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L43
        How to set ONNX fields?
        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L88
    """
    tensor_proto = onnx.TensorProto()
    tensor_proto.name = name
    tensor_proto.data_type = torch.onnx._type_utils._SCALAR_TYPE_TO_ONNX[  # type: ignore[assignment]
        torch.onnx._type_utils._DTYPE_TO_SCALAR_TYPE[tensor.dtype]
    ]
    tensor_proto.dims.extend(tensor.shape)
    tensor_proto.data_location = onnx.TensorProto.EXTERNAL

    # Settings for saving one tensor per file.
    # Offset is zero because there is no other tensor in the same file.
    key_value_pairs = {
        "location": location,
        "offset": 0,
        "length": tensor.untyped_storage().nbytes(),
    }
    for k, v in key_value_pairs.items():
        entry = tensor_proto.external_data.add()
        entry.key = k
        entry.value = str(v)

    # Actual path to write the tensor's content.
    external_data_file_path = os.path.join(basepath, location)
    if os.path.exists(external_data_file_path):
        os.remove(external_data_file_path)

    # Create the external data's folder if it does not exist yet.
    external_data_dir_path = os.path.dirname(external_data_file_path)
    if not os.path.exists(external_data_dir_path):
        os.makedirs(external_data_dir_path)

    # Create a fresh file.
    with open(external_data_file_path, "xb") as data_file:
        # No need to call "seek" because offset is 0.
        # data_file.seek(0)
        # Write tensor content to the file.
        data_file.write(tensor.numpy().tobytes())

    return tensor_proto


@_beartype.beartype
def save_model_with_external_data(
    basepath: str,
    model_location: str,
    initializer_location: str,
    torch_load_paths: Tuple[str, ...],
    onnx_model: onnx.ModelProto,
) -> None:
    """Load PyTorch tensors from files and add them to "onnx_model" as external initializers.

    Output files:
        ONNX model file: os.path.join(basepath, model_location)
        ONNX initializer folder: os.path.join(basepath, initializer_location)

    After running this function, you can do
        ort_sess = onnxruntime.InferenceSession(os.path.join(basepath, model_location))
    to execute the model.

    Arguments:
        basepath: Base path of the external data file (e.g., "/tmp/large-onnx-model").
        model_location: Relative location of the ONNX model file.
            E.g., "model.onnx" so that the model file is saved to
            "/tmp/large-onnx-model/model.onnx".
        initializer_location: Relative location of the ONNX initializer folder.
            E.g., "initializers" so that the initializers are saved to
            "/tmp/large-onnx-model/initializers".
        torch_load_paths: Files containing serialized PyTorch tensors to be saved
            as ONNX initializers. They are loaded by torch.load.
        onnx_model: ONNX model to be saved with external initializers.
            If an input name matches a tensor loaded from "torch_load_paths",
            the tensor will be saved as that input's external initializer.
    """
    onnx_model_with_initializers = onnx.ModelProto()
    onnx_model_with_initializers.CopyFrom(onnx_model)
    onnx_input_names = [input.name for input in onnx_model.graph.input]

    for path in torch_load_paths:
        state_ditc = torch.load(path)
        for name, tensor in state_ditc.items():
            # Basically, "transformer.attention.self.query.weight" is mapped
            # to "transformer_attention_self_query_weight" to mimic the
            # name-modifying code in the FX-to-ONNX exporter.
            # See function _replace_get_attr_with_placeholder for details.
            refined_name = name.replace(".", "_")

            # For each refined PyTorch tensor name loaded by torch.load,
            #  1. Search for its best match in the ONNX model. E.g., the match of
            #     "transformer_attention_weight" could be "attention_weight".
            #  2. Set "tensor" as the initializer of the matched ONNX input.
            #     E.g., "tensor" is stored as the initializer of "attention_weight".
            # Step 1 is required because sometimes tensor names are stored with a
            # prefix in the dictionary loaded by torch.load.
            for onnx_input_name in onnx_input_names:
                if onnx_input_name.endswith(refined_name) or refined_name.endswith(
                    onnx_input_name
                ):
                    # Found a match. Change refined_name to the matched ONNX input
                    # name, so that we create the initializer with the right ONNX name.
                    refined_name = onnx_input_name
                    break

            relative_tensor_file_path = os.path.join(initializer_location, refined_name)
            # Create one file per tensor.
            # tensor_proto.raw_data is stored to an external file at
            # os.path.join(basepath, relative_tensor_file_path).
            tensor_proto = _create_tensor_proto_with_external_data(
                tensor, refined_name, relative_tensor_file_path, basepath
            )
            # Add the tensor_proto to the ONNX model as an initializer with external data.
            onnx_model_with_initializers.graph.initializer.append(tensor_proto)

    # model_location should be a pure file name such as "file_name.onnx", not "folder/file_name.onnx".
    onnx.save(onnx_model_with_initializers, os.path.join(basepath, model_location))
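A hedged usage sketch of the new serialization module, following the docstring above; my_module, model_proto, and all paths are illustrative assumptions, and model_proto is expected to come from export_without_parameters_and_buffers so that its graph inputs carry the refined tensor names:

import os

import onnxruntime  # assumption: onnxruntime is installed; serialization.py itself does not import it
import torch
from torch.onnx._internal.fx.serialization import save_model_with_external_data

# Assumption: my_module is your torch.nn.Module and the directory already exists.
torch.save(my_module.state_dict(), "/tmp/large-onnx-model/weights.pt")

save_model_with_external_data(
    basepath="/tmp/large-onnx-model",
    model_location="model.onnx",          # saved to /tmp/large-onnx-model/model.onnx
    initializer_location="initializers",  # one file per tensor under /tmp/large-onnx-model/initializers
    torch_load_paths=("/tmp/large-onnx-model/weights.pt",),
    onnx_model=model_proto,               # assumption: an onnx.ModelProto produced earlier
)

# Per the docstring, the saved model can now be executed directly:
ort_sess = onnxruntime.InferenceSession(os.path.join("/tmp/large-onnx-model", "model.onnx"))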
torch/onnx/_internal/fx/symbolic_exporter.py (new file, 289 lines)
@@ -0,0 +1,289 @@
from __future__ import annotations

import functools
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import onnx

import torch
import torch.fx
from torch.onnx import _constants
from torch.onnx._internal import _beartype
from torch.onnx._internal.fx import exporter

# Functions directly wrapped to produce torch.fx.Proxy so that symbolic
# data can flow through those functions. Python functions (e.g., `torch.arange`)
# not defined by pybind11 in C++ do not go through the Python dispatcher, so
# they are not automatically patched by FX's Python dispatcher.
# The list below means `torch.arange`, `torch.tensor`, and so on will be
# patched.
_TORCH_METHODS_TO_PATCH: Tuple[str, ...] = (
    "arange",
    "tensor",
    "finfo",
    "full",
    "empty",
)


class ModuleExpansionTracer(torch.fx._symbolic_trace.Tracer):
    """Tracer to create an ONNX-export-friendly FX graph.

    This tracer traces models into operators. That is,
    the traced graph mostly contains call_function nodes and
    has no call_module nodes. The call_module nodes
    are problematic for the use of make_fx(...) in the ONNX
    exporter.
    """

    @_beartype.beartype
    def is_leaf_module(
        self, module: torch.nn.Module, module_qualified_name: str
    ) -> bool:
        # This returns False so that all sub-modules are considered not leaves
        # and therefore expanded into operators in
        # torch.fx._symbolic_trace.Tracer.call_module.
        return False

    @_beartype.beartype
    def to_bool(self, obj: "torch.fx.Proxy") -> bool:
        # FIXME: This is a hack to trace through if-else Python blocks.
        # It may generate incorrect ONNX graphs if the if-else block
        return False


@_beartype.beartype
def _move_placeholder_to_front(graph_module: torch.fx.GraphModule) -> None:
    """
    This function moves all placeholder nodes to the front of the graph node list.
    In torch.fx.Graph, placeholder is a special assignment node. If it's not
    executed at the beginning, it could overwrite values computed by upstream
    nodes.
    """

    graph = graph_module.graph
    placeholders = []
    first_not_placeholder = None
    for node in graph.nodes:
        if node.op == "placeholder":
            placeholders.append(node)
        if first_not_placeholder is None and node.op != "placeholder":
            first_not_placeholder = node
    if first_not_placeholder is None:
        return
    for placeholder in placeholders:
        first_not_placeholder.prepend(placeholder)


@_beartype.beartype
def _replace_get_attr_with_placeholder(
    graph_module: torch.fx.GraphModule,
) -> Tuple[torch.Tensor, ...]:
    """
    Replace get_attr with placeholder.
    The parameters and buffers accessed by the original get_attr are returned;
    they are useful when creating random inputs for the modified graph_module.
    """
    graph = graph_module.graph
    replaced_attrs: List[torch.Tensor] = []
    for node in graph.nodes:
        if node.op == "get_attr":
            replaced_attr: Optional[torch.Tensor] = None
            # get_attr could retrieve either a parameter or a buffer, so
            # we need to try both.
            try:
                replaced_attr = graph_module.get_parameter(node.target)
            except AttributeError:
                # It's possible that the model author uses a buffer instead of
                # a parameter to store trainable weights. In this case,
                #  1. get_parameter will throw something like
                #     AttributeError: `bias` is not an nn.Parameter.
                #  2. get_buffer should work.
                replaced_attr = graph_module.get_buffer(node.target)

            # Reassign op type so that the get_attr node becomes a placeholder node.
            node.op = "placeholder"
            # The target name in placeholder must be a valid Python identifier.
            # Thus, we replace, e.g., "module.submodule.weight" with
            # "module_submodule_weight".
            node.target = node.target.replace(".", "_")
            # Default value is None. This is needed as long as the "graph_module"
            # has optional inputs. Assume the original forward signature is
            #   def forward(self, x, y=None)
            # and the replaced get_attr node has target "z". Then, the modified
            # signature should be
            #   def forward(self, x, y=None, z=None)
            # Without the following line, the signature would be
            #   def forward(self, x, y=None, z)
            # , which is not valid Python code.
            node.args = (None,)

            replaced_attrs.append(replaced_attr)

    return tuple(replaced_attrs)


@_beartype.beartype
def _trace_into_fx_graph_via_fx_symbolic_trace(
    module: torch.nn.Module,
    *args,
    # kwargs are the keyword arguments to call "module"; that is,
    # module(*args, **kwargs) must run.
    **kwargs,
) -> Tuple["torch.fx.GraphModule", Tuple[Any, ...]]:
    signature = inspect.signature(module.forward)

    # We hope the input kwargs will be mapped to bound.args after binding.
    # If not, we will raise an error.
    bound = signature.bind(*args, **kwargs)
    bound.apply_defaults()
    # After apply_defaults, all non keyword-only arguments are in bound.args.
    # Because the code below does not support keyword arguments, bound.kwargs
    # must be empty.
    assert len(bound.kwargs) == 0, bound.kwargs

    # Create inputs to call symbolic trace (torch.fx.symbolic_trace).
    # Example content of concrete_args:
    #   concrete_args["x"] = torch.fx._symbolic_trace.PH
    #   concrete_args["b"] = 1
    # where "x" and "b" are argument names in "signature".
    concrete_args = {}
    for param_name, param_value in bound.arguments.items():
        if isinstance(param_value, torch.Tensor):
            # param_value can be, e.g., a real tensor or a fake tensor.
            # param_value is treated as a substitutable tensor symbol (aka placeholder).
            concrete_args[param_name] = torch.fx._symbolic_trace.PH
        else:
            concrete_args[param_name] = param_value

    return (
        _module_expansion_symbolic_trace(module, concrete_args=concrete_args),
        bound.args,
    )


def _wrap_for_symbolic_trace(target: Callable) -> Tuple[Callable, Callable]:
    """This function wraps ``target`` for symbolic tracing.

    This function wraps ``target`` so that its wrapper produces
    torch.fx.Proxy in symbolic computation. The returned values are
    the wrapper and then the original function. Per `_TORCH_METHODS_TO_PATCH`,
    this function shall receive `torch.arange`, `torch.tensor`, etc. as inputs.
    """

    @functools.wraps(target)
    def wrapper(*args, **kwargs):
        proxy = None

        def check_has_proxy(v):
            if isinstance(v, torch.fx.Proxy):
                nonlocal proxy
                proxy = v

        torch.fx.node.map_aggregate(args, check_has_proxy)
        torch.fx.node.map_aggregate(kwargs, check_has_proxy)

        if proxy is not None:
            return proxy.tracer.create_proxy("call_function", target, args, kwargs)
        else:
            return target(*args, **kwargs)

    return wrapper, target


@_beartype.beartype
def _module_expansion_symbolic_trace(
    root: Union[torch.nn.Module, Callable[..., Any]],
    concrete_args: Optional[Dict[str, Any]] = None,
) -> torch.fx.GraphModule:
    """Trace a callable into an FX graph.

    When "root" is a torch.nn.Module, calls to its submodules (type: torch.nn.Module)
    will be expanded into operators (e.g., torch.matmul, torch.add, +, and -) to
    simplify the graph structure.
    """
    # For functions that don't support symbolic tracing, create wrappers
    # which produce symbolic results during tracing.
    patched_torch_methods = {
        target_name: _wrap_for_symbolic_trace(getattr(torch, target_name))
        for target_name in _TORCH_METHODS_TO_PATCH
    }

    # Set the symbolic-tracing friendly functions so that `tracer.trace` below
    # can work.
    for name, (wrapper, _) in patched_torch_methods.items():
        setattr(torch, name, wrapper)

    try:
        # Set up a tracer.
        tracer = ModuleExpansionTracer()
        # Trace the model.
        graph = tracer.trace(root, concrete_args)
        name = (
            root.__class__.__name__
            if isinstance(root, torch.nn.Module)
            else root.__name__
        )
        return torch.fx.GraphModule(tracer.root, graph, name)
    finally:
        # Revert the patches for symbolic tracing.
        for name, (_, wrapped) in patched_torch_methods.items():
            # wrapped is the original version of `torch.<name>`.
            setattr(torch, name, wrapped)


@_beartype.beartype
def export_without_parameters_and_buffers(
    module: torch.nn.Module,
    *args,
    decomposition_table: Optional[Dict[torch._ops.OpOverload, Callable]] = None,
    use_binary_format: bool = True,
    opset_version: int = _constants.ONNX_DEFAULT_OPSET,
    op_level_debug: bool = False,
    # kwargs are the keyword arguments to call "module"; that is,
    # module(*args, **kwargs) must run.
    **kwargs,
) -> Tuple[
    Union[onnx.ModelProto, bytes],
    torch.fx.GraphModule,
    Tuple[Any, ...],
    Tuple[Any, ...],
]:

    graph_module, bound_args = _trace_into_fx_graph_via_fx_symbolic_trace(
        module, *args, **kwargs
    )

    # Make sure all placeholder nodes are executed before get_attr nodes.
    # Otherwise, inputs can interleave with initializers in the final
    # ModelProto.graph.input. Basically, we want
    #   ModelProto.graph.input =
    #     [input_0, input_1, ..., input_n, weight_0, weight_1, ..., weight_m]
    # and we don't want
    #   ModelProto.graph.input =
    #     [input_0, weight_0, input_1, weight_1, ..., input_n, weight_0, weight_1, ..., weight_m]
    _move_placeholder_to_front(graph_module)
    # To save memory, move get_attr to input so that the generated model doesn't
    # hold weight tensors. "replaced_attrs" is the list of replaced weight tensors.
    replaced_attrs = _replace_get_attr_with_placeholder(graph_module)
    # Move all newly created placeholder nodes to the front of the graph.
    _move_placeholder_to_front(graph_module)
    # Finalize the graph editing.
    graph_module.recompile()
    return (
        exporter._export(
            graph_module,
            (*bound_args, *replaced_attrs),
            opset_version=opset_version,
            decomposition_table=decomposition_table,
            use_binary_format=use_binary_format,
            op_level_debug=op_level_debug,
        ),
        graph_module,
        bound_args,
        replaced_attrs,
    )
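And a hedged sketch of calling the moved export entry point; MyModel and the input shape are illustrative assumptions, while the keyword arguments and the four return values follow the signature above:

import torch
from torch.onnx._internal.fx.symbolic_exporter import export_without_parameters_and_buffers

model = MyModel()  # assumption: any torch.nn.Module whose forward takes one tensor
dummy_input = torch.randn(1, 3, 224, 224)

onnx_model, graph_module, bound_args, replaced_attrs = export_without_parameters_and_buffers(
    model,
    dummy_input,
    use_binary_format=False,  # keep an onnx.ModelProto rather than serialized bytes
)
# replaced_attrs holds the parameters/buffers that were turned into extra graph
# inputs; pairing it with save_model_with_external_data above re-attaches them
# as external initializers.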