[ONNX] Move symbolic export to separate file (#95650)

Move things around in preparation for refactoring the code structure.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95650
Approved by: https://github.com/titaiwangms
BowenBao 2023-03-07 10:31:37 -08:00 committed by PyTorch MergeBot
parent d06729746c
commit 82dba844bb
4 changed files with 441 additions and 422 deletions

torch/onnx/_internal/fx/__init__.py

@@ -1,10 +1,7 @@
 from .context import FxToOnnxContext
-from .exporter import (
-    export,
-    export_after_normalizing_args_and_kwargs,
-    export_without_parameters_and_buffers,
-    save_model_with_external_data,
-)
+from .exporter import export, export_after_normalizing_args_and_kwargs
+from .serialization import save_model_with_external_data
+from .symbolic_exporter import export_without_parameters_and_buffers
 __all__ = [

torch/onnx/_internal/fx/exporter.py

@@ -1,11 +1,9 @@
 from __future__ import annotations
 import copy
-import functools
 import inspect
 import itertools
 import operator
-import os
 import re
 import warnings
 from types import FunctionType
@@ -75,117 +73,6 @@ onnxscript.OnnxFunction.__call__ = _diagnose_onnx_function(
 )
class ModuleExpansionTracer(torch.fx._symbolic_trace.Tracer):
"""Tracer to create ONNX-exporting friendly FX graph.
This tracer traces models into operators. That is,
the traced graph mostly contains call_function nodes and
has no call_module nodes. The call_module nodes
are problematic to the use of make_fx(...) in ONNX
exporter.
"""
@_beartype.beartype
def is_leaf_module(
self, module: torch.nn.Module, module_qualified_name: str
) -> bool:
# This returns False so that all sub-modules are considered non-leaf
# and therefore expanded into operators by
# torch.fx._symbolic_trace.Tracer.call_module.
return False
@_beartype.beartype
def to_bool(self, obj: "torch.fx.Proxy") -> bool:
# This is a hack to trace through Python if-else blocks.
# It may generate incorrect ONNX graphs if the branch taken
# depends on the traced tensor's value.
return False
# Functions directly wrapped to produce torch.fx.Proxy so that symbolic
# data can flow through those functions. Python functions (e.g., `torch.arange`)
# not defined by pybind11 in C++ do not go through the Python dispatcher, so
# they are not automatically patched by FX's Python dispatcher.
# The list below means `torch.arange`, `torch.tensor`, and so on will be
# patched.
_TORCH_METHODS_TO_PATCH: Tuple[str, ...] = (
"arange",
"tensor",
"finfo",
"full",
"empty",
)
def _wrap_for_symbolic_trace(target: Callable) -> Tuple[Callable, Callable]:
"""This function wraps ```target`` for symbolic tracing.
This function wraps ```target``` so that its wrapper produces
torch.fx.Proxy in symbolic computation. The returned values are
the wrapper and then the original function. Per `_TORCH_METHODS_TO_PATCH`,
this function shall receive `torch.arange`, `torch.tensor`, etc. as inputs.
"""
@functools.wraps(target)
def wrapper(*args, **kwargs):
proxy = None
def check_has_proxy(v):
if isinstance(v, torch.fx.Proxy):
nonlocal proxy
proxy = v
torch.fx.node.map_aggregate(args, check_has_proxy)
torch.fx.node.map_aggregate(kwargs, check_has_proxy)
if proxy is not None:
return proxy.tracer.create_proxy("call_function", target, args, kwargs)
else:
return target(*args, **kwargs)
return wrapper, target
@_beartype.beartype
def _module_expansion_symbolic_trace(
root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None,
) -> "torch.fx.GraphModule":
"""Trace a callable into FX graph.
When "root" is torch.nn.Module, calls to its submodule (type: torch.nn.Module) will be
expanded into operators (e.g., torch.matmul, torch.add, +, and -) to simplify graph
structure.
"""
# For functions doesn't support symbolic tracing, create wrappers
# which produce symbolic results during tracing.
patched_torch_methods = {
target_name: _wrap_for_symbolic_trace(getattr(torch, target_name))
for target_name in _TORCH_METHODS_TO_PATCH
}
# Set the symbolic-tracing friendly functions so that `tracer.trace` below
# can work.
for name, (wrapper, _) in patched_torch_methods.items():
setattr(torch, name, wrapper)
try:
# Set up a tracer.
tracer = ModuleExpansionTracer()
# Trace the model.
graph = tracer.trace(root, concrete_args)
name = (
root.__class__.__name__
if isinstance(root, torch.nn.Module)
else root.__name__
)
return torch.fx.GraphModule(tracer.root, graph, name)
finally:
# Revert the patches for symbolic tracing.
for name, (_, wrapped) in patched_torch_methods.items():
# `wrapped` is the original version of `torch.<name>`.
setattr(torch, name, wrapped)
def _retrieve_or_adapt_input_to_graph_set(fx_node_arg, fx_name_to_onnxscipt_value):
"""Map FX value to TorchScript value.
@@ -780,309 +667,6 @@ def export_after_normalizing_args_and_kwargs(
)
@_beartype.beartype
def _move_placeholder_to_front(graph_module: "torch.fx.GraphModule") -> None:
"""
This function moves all placeholder nodes to the front of the graph node list.
In torch.fx.Graph, placeholder is a special assignment node. If it's not
executed at the beginning, it could overwrite values computed by upstream
nodes.
"""
graph = graph_module.graph
placeholders = []
first_not_placeholder = None
for node in graph.nodes:
if node.op == "placeholder":
placeholders.append(node)
if first_not_placeholder is None and node.op != "placeholder":
first_not_placeholder = node
if first_not_placeholder is None:
return
for placeholder in placeholders:
first_not_placeholder.prepend(placeholder)
@_beartype.beartype
def _replace_get_attr_with_placeholder(
graph_module: "torch.fx.GraphModule",
) -> Tuple[torch.Tensor, ...]:
"""
Replace get_attr with placeholder.
The parameters and buffers accessed by the original get_attr are returned;
they are useful when creating random inputs for the modified graph_module.
"""
graph = graph_module.graph
replaced_attrs: List[torch.Tensor] = []
for node in graph.nodes:
if node.op == "get_attr":
replaced_attr: Optional[torch.Tensor] = None
# get_attr could retrieve either parameter or buffer, so
# we need to try both.
try:
replaced_attr = graph_module.get_parameter(node.target)
except AttributeError:
# It's possible that the model author used a buffer instead of a
# parameter to store trainable weights. In this case,
# 1. get_parameter will raise something like
#    AttributeError: `bias` is not an nn.Parameter.
# 2. get_buffer should work.
replaced_attr = graph_module.get_buffer(node.target)
# Reassign op type so that get_attr node becomes placeholder node.
node.op = "placeholder"
# The target name in placeholder must be a valid Python identifier.
# Thus, we replace, e.g., "module.submodule.weight" with
# "module_submodule_weight".
node.target = node.target.replace(".", "_")
# Default value is None. This is needed as long as the "graph_module"
# has optional inputs. Assume the original forward signature is
# def forward(self, x, y=None)
# and the replaced get_attr node has target "z". Then, the modified
# signature should be
# def forward(self, x, y=None, z=None)
# Without the following line, the signature will be
# def forward(self, x, y=None, z)
# , which is not valid Python code.
node.args = (None,)
replaced_attrs.append(replaced_attr)
return tuple(replaced_attrs)
@_beartype.beartype
def _trace_into_fx_graph_via_fx_symbolic_trace(
module: torch.nn.Module,
*args,
# kwargs are the keyword arguments to call "module"; that is,
# module(*args, **kwargs) must run.
**kwargs,
) -> Tuple["torch.fx.GraphModule", Tuple[Any, ...]]:
signature = inspect.signature(module.forward)
# We hope the input kwargs will be mapped to bound.args after binding.
# If not, we will raise an error.
bound = signature.bind(*args, **kwargs)
bound.apply_defaults()
# After apply_defaults, all non keyword-only arguments are in bound.args.
# Because the code below does not support keyword-only arguments,
# bound.kwargs must be empty.
assert len(bound.kwargs) == 0, bound.kwargs
# Create inputs to call symbolic trace (torch.fx.symbolic_trace)
# Example content of concrete_args:
# concrete_args["x"] = torch.fx._symbolic_trace.PH
# concrete_args["b"] = 1
# where "x" and "b" are argument names in "signature".
concrete_args = {}
for param_name, param_value in bound.arguments.items():
if isinstance(param_value, torch.Tensor):
# param_value can be, e.g., a real tensor or a fake tensor.
# param_value is treated as substitutable tensor symbol (aka placeholder).
concrete_args[param_name] = torch.fx._symbolic_trace.PH
else:
concrete_args[param_name] = param_value
return (
_module_expansion_symbolic_trace(module, concrete_args=concrete_args),
bound.args,
)
@_beartype.beartype
def export_without_parameters_and_buffers(
module: torch.nn.Module,
*args,
decomposition_table: Optional[Dict[torch._ops.OpOverload, Callable]] = None,
use_binary_format: bool = True,
opset_version: int = _constants.ONNX_DEFAULT_OPSET,
op_level_debug: bool = False,
# kwargs are the keyword arguments to call "module"; that is,
# module(*args, **kwargs) must run.
**kwargs,
) -> Tuple[
Union["onnx.ModelProto", bytes],
"torch.fx.GraphModule",
Tuple[Any, ...],
Tuple[Any, ...],
]:
graph_module, bound_args = _trace_into_fx_graph_via_fx_symbolic_trace(
module, *args, **kwargs
)
# Make sure all placeholder nodes are executed before get_attr nodes.
# Otherwise, inputs can interleave with initializers in the final ModelProto.graph.input.
# Basically, we want
# ModelProto.graph.input =
# [input_0, input_1, ..., input_n, weight_0, weight_1, ..., weight_m]
# and we don't want
# ModelProto.graph.input =
# [input_0, weight_0, input_1, weight_1, ...]
_move_placeholder_to_front(graph_module)
# To save memory, move get_attr nodes to inputs so that the generated model doesn't
# have weight tensors. "replaced_attrs" is the tuple of replaced weight tensors.
replaced_attrs = _replace_get_attr_with_placeholder(graph_module)
# Move all newly created placeholder nodes to the front of the graph.
_move_placeholder_to_front(graph_module)
# Finalize the graph editing.
graph_module.recompile()
return (
_export(
graph_module,
(*bound_args, *replaced_attrs),
opset_version=opset_version,
decomposition_table=decomposition_table,
use_binary_format=use_binary_format,
op_level_debug=op_level_debug,
),
graph_module,
bound_args,
replaced_attrs,
)
@_beartype.beartype
def _create_tensor_proto_with_external_data(
tensor: torch.Tensor, name: str, location: str, basepath: str
) -> "onnx.TensorProto":
"""Create a TensorProto with external data from a PyTorch tensor.
The external data is saved to os.path.join(basepath, location).
Args:
tensor: Tensor to be saved.
name: Name of the tensor (i.e., initializer name in ONNX graph).
location: Relative location of the external data file
(e.g., "initializers/weight_0" when the model is "/tmp/model_name.onnx").
basepath: Base path of the external data file (e.g., "/tmp/external_data" while the model must be in "/tmp").
Reference for ONNX's external data format:
How to load?
https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L187
How to save?
https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L43
How to set ONNX fields?
https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L88
"""
tensor_proto = onnx.TensorProto()
tensor_proto.name = name
tensor_proto.data_type = torch.onnx._type_utils._SCALAR_TYPE_TO_ONNX[ # type: ignore[assignment]
torch.onnx._type_utils._DTYPE_TO_SCALAR_TYPE[tensor.dtype]
]
tensor_proto.dims.extend(tensor.shape)
tensor_proto.data_location = onnx.TensorProto.EXTERNAL
# Settings for saving one tensor per file.
# Offset is zero because there is no other tensor in the same file.
key_value_pairs = {
"location": location,
"offset": 0,
"length": tensor.untyped_storage().nbytes(),
}
for k, v in key_value_pairs.items():
entry = tensor_proto.external_data.add()
entry.key = k
entry.value = str(v)
# Actual path to which the tensor content is written.
external_data_file_path = os.path.join(basepath, location)
if os.path.exists(external_data_file_path):
os.remove(external_data_file_path)
# Create the external data directory if it does not exist.
external_data_dir_path = os.path.dirname(external_data_file_path)
if not os.path.exists(external_data_dir_path):
os.makedirs(external_data_dir_path)
# Create a fresh file.
with open(external_data_file_path, "xb") as data_file:
# No need to call "seek" because offset is 0.
# data_file.seek(0)
# Write tensor content to the file.
data_file.write(tensor.numpy().tobytes())
return tensor_proto
@_beartype.beartype
def save_model_with_external_data(
basepath: str,
model_location: str,
initializer_location: str,
torch_load_paths: Tuple[str, ...],
onnx_model: "onnx.ModelProto",
) -> None:
"""Load PyTorch tensors from files and add to "onnx_model" as external initializers.
Output files:
ONNX model file path:
ONNX initializer folder: os.path.join(basepath, initializer_location)
After running this function, you can do
ort_sess = onnxruntime.InferenceSession(os.path.join(basepath, model_location))
to execute the model.
Arguments:
basepath: Base path of the external data file (e.g., "/tmp/large-onnx-model").
model_location: Relative location of the ONNX model file.
E.g., "model.onnx" so that the model file is saved to
"/tmp/large-onnx-model/model.onnx".
initializer_location: Relative location of the ONNX initializer folder.
E.g., "initializers" so that the initializers are saved to
"/tmp/large-onnx-model/initializers".
torch_load_paths: Files containing serialized PyTorch tensors to be saved
as ONNX initializers. They are loaded by torch.load.
onnx_model: ONNX model to be saved with external initializers.
If an input name matches a tensor loaded from "torch_load_paths",
the tensor will be saved as that input's external initializer.
"""
onnx_model_with_initializers = onnx.ModelProto()
onnx_model_with_initializers.CopyFrom(onnx_model)
onnx_input_names = [input.name for input in onnx_model.graph.input]
for path in torch_load_paths:
state_dict = torch.load(path)
for name, tensor in state_dict.items():
# Basically, "transformer.attention.self.query.weight" is mapped
# to "transformer_attention_self_query_weight" to mimic the
# name-modifying code in the FX-to-ONNX exporter.
# See function _replace_get_attr_with_placeholder for details.
refined_name = name.replace(".", "_")
# For each refined PyTorch tensor name loaded by torch.load,
# 1. Search its best match in ONNX model. E.g., the match of
# "transformer_attention_weight" could be "attention_weight".
# 2. Set "tensor" as the initializer of the matched ONNX input.
# E.g., "tensor" is stored as the initializer of "attention_weight".
# Step 1 is required because tensor names are sometimes stored with a
# prefix in the dictionary loaded by torch.load.
for onnx_input_name in onnx_input_names:
if onnx_input_name.endswith(refined_name) or refined_name.endswith(
onnx_input_name
):
# Found a match. Change refined_name to the matched ONNX input name, so that we
# create the initializer with the right ONNX name.
refined_name = onnx_input_name
break
relative_tensor_file_path = os.path.join(initializer_location, refined_name)
# Create one file per tensor.
# tensor_proto.raw_data is stored to external file at
# os.path.join(basepath, relative_tensor_file_path).
tensor_proto = _create_tensor_proto_with_external_data(
tensor, refined_name, relative_tensor_file_path, basepath
)
# Add the tensor_proto to the ONNX model as an initializer with external data.
onnx_model_with_initializers.graph.initializer.append(tensor_proto)
# model_location should be a pure file name such as "file_name.onnx", not "folder/file_name.onnx".
onnx.save(onnx_model_with_initializers, os.path.join(basepath, model_location))
@_beartype.beartype
def _validate_op_between_ort_torch(
node: torch.fx.Node,

torch/onnx/_internal/fx/serialization.py

@@ -0,0 +1,149 @@
from __future__ import annotations
import os
from typing import Tuple
import onnx
import torch
from torch.onnx._internal import _beartype
@_beartype.beartype
def _create_tensor_proto_with_external_data(
tensor: torch.Tensor, name: str, location: str, basepath: str
) -> "onnx.TensorProto":
"""Create a TensorProto with external data from a PyTorch tensor.
The external data is saved to os.path.join(basepath, location).
Args:
tensor: Tensor to be saved.
name: Name of the tensor (i.e., initializer name in ONNX graph).
location: Relative location of the external data file
(e.g., "initializers/weight_0" when the model is "/tmp/model_name.onnx").
basepath: Base path of the external data file (e.g., "/tmp/external_data" while the model must be in "/tmp").
Reference for ONNX's external data format:
How to load?
https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L187
How to save?
https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L43
How to set ONNX fields?
https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L88
"""
tensor_proto = onnx.TensorProto()
tensor_proto.name = name
tensor_proto.data_type = torch.onnx._type_utils._SCALAR_TYPE_TO_ONNX[ # type: ignore[assignment]
torch.onnx._type_utils._DTYPE_TO_SCALAR_TYPE[tensor.dtype]
]
tensor_proto.dims.extend(tensor.shape)
tensor_proto.data_location = onnx.TensorProto.EXTERNAL
# Settings for saving one tensor per file.
# Offset is zero because there is no other tensor in the same file.
key_value_pairs = {
"location": location,
"offset": 0,
"length": tensor.untyped_storage().nbytes(),
}
for k, v in key_value_pairs.items():
entry = tensor_proto.external_data.add()
entry.key = k
entry.value = str(v)
# Actual path to which the tensor content is written.
external_data_file_path = os.path.join(basepath, location)
if os.path.exists(external_data_file_path):
os.remove(external_data_file_path)
# Create the external data directory if it does not exist.
external_data_dir_path = os.path.dirname(external_data_file_path)
if not os.path.exists(external_data_dir_path):
os.makedirs(external_data_dir_path)
# Create a fresh file.
with open(external_data_file_path, "xb") as data_file:
# No need to call "seek" because offset is 0.
# data_file.seek(0)
# Write tensor content to the file.
data_file.write(tensor.numpy().tobytes())
return tensor_proto
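def _example_create_external_tensor_proto() -> "onnx.TensorProto":
    """Illustrative sketch, not part of this commit: create an external-data
    proto for a small tensor under a hypothetical "/tmp/demo" base path."""
    proto = _create_tensor_proto_with_external_data(
        torch.zeros(2, 3),
        name="weight_0",
        location="initializers/weight_0",
        basepath="/tmp/demo",
    )
    # The raw bytes now live in /tmp/demo/initializers/weight_0; the proto
    # only records the location/offset/length entries needed to find them.
    assert proto.data_location == onnx.TensorProto.EXTERNAL
    return proto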
@_beartype.beartype
def save_model_with_external_data(
basepath: str,
model_location: str,
initializer_location: str,
torch_load_paths: Tuple[str, ...],
onnx_model: onnx.ModelProto,
) -> None:
"""Load PyTorch tensors from files and add to "onnx_model" as external initializers.
Output files:
ONNX model file path:
ONNX initializer folder: os.path.join(basepath, initializer_location)
After running this function, you can do
ort_sess = onnxruntime.InferenceSession(os.path.join(basepath, model_location))
to execute the model.
Arguments:
basepath: Base path of the external data file (e.g., "/tmp/large-onnx-model").
model_location: Relative location of the ONNX model file.
E.g., "model.onnx" so that the model file is saved to
"/tmp/large-onnx-model/model.onnx".
initializer_location: Relative location of the ONNX initializer folder.
E.g., "initializers" so that the initializers are saved to
"/tmp/large-onnx-model/initializers".
torch_load_paths: Files containing serialized PyTorch tensors to be saved
as ONNX initializers. They are loaded by torch.load.
onnx_model: ONNX model to be saved with external initializers.
If an input name matches a tensor loaded from "torch_load_paths",
the tensor will be saved as that input's external initializer.
"""
onnx_model_with_initializers = onnx.ModelProto()
onnx_model_with_initializers.CopyFrom(onnx_model)
onnx_input_names = [input.name for input in onnx_model.graph.input]
for path in torch_load_paths:
state_dict = torch.load(path)
for name, tensor in state_dict.items():
# Basically, "transformer.attention.self.query.weight" is mapped
# to "transformer_attention_self_query_weight" to mimic the
# name-modifying code in the FX-to-ONNX exporter.
# See function _replace_get_attr_with_placeholder for details.
refined_name = name.replace(".", "_")
# For each refined PyTorch tensor name loaded by torch.load,
# 1. Search its best match in ONNX model. E.g., the match of
# "transformer_attention_weight" could be "attention_weight".
# 2. Set "tensor" as the initializer of the matched ONNX input.
# E.g., "tensor" is stored as the initializer of "attention_weight".
# Step 1 is required because tensor names are sometimes stored with a
# prefix in the dictionary loaded by torch.load.
for onnx_input_name in onnx_input_names:
if onnx_input_name.endswith(refined_name) or refined_name.endswith(
onnx_input_name
):
# Found a match. Change refined_name to the matched ONNX input name, so that we
# create the initializer with the right ONNX name.
refined_name = onnx_input_name
break
relative_tensor_file_path = os.path.join(initializer_location, refined_name)
# Create one file per tensor.
# tensor_proto.raw_data is stored to external file at
# os.path.join(basepath, relative_tensor_file_path).
tensor_proto = _create_tensor_proto_with_external_data(
tensor, refined_name, relative_tensor_file_path, basepath
)
# Add the tensor_proto to the ONNX model as an initializer with external data.
onnx_model_with_initializers.graph.initializer.append(tensor_proto)
# model_location should be a pure file name such as "file_name.onnx", not "folder/file_name.onnx".
onnx.save(onnx_model_with_initializers, os.path.join(basepath, model_location))
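def _example_save_model_with_external_data(onnx_model: "onnx.ModelProto") -> None:
    """Illustrative sketch, not part of this commit: attach weights from a
    hypothetical checkpoint file (a torch.save()'d state_dict) as external
    initializers, then save the model next to them."""
    save_model_with_external_data(
        basepath="/tmp/large-onnx-model",
        model_location="model.onnx",  # -> /tmp/large-onnx-model/model.onnx
        initializer_location="initializers",  # one file per tensor
        torch_load_paths=("/tmp/checkpoint.pt",),
        onnx_model=onnx_model,
    )
    # The saved model can then be executed with, e.g.,
    # onnxruntime.InferenceSession("/tmp/large-onnx-model/model.onnx").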

torch/onnx/_internal/fx/symbolic_exporter.py

@@ -0,0 +1,289 @@
from __future__ import annotations
import functools
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import onnx
import torch
import torch.fx
from torch.onnx import _constants
from torch.onnx._internal import _beartype
from torch.onnx._internal.fx import exporter
# Functions directly wrapped to produce torch.fx.Proxy so that symbolic
# data can flow through those functions. Python functions (e.g., `torch.arange`)
# not defined by pybind11 in C++ do not go through the Python dispatcher, so
# they are not automatically patched by FX's Python dispatcher.
# The list below means `torch.arange`, `torch.tensor`, and so on will be
# patched.
_TORCH_METHODS_TO_PATCH: Tuple[str, ...] = (
"arange",
"tensor",
"finfo",
"full",
"empty",
)
class ModuleExpansionTracer(torch.fx._symbolic_trace.Tracer):
"""Tracer to create ONNX-exporting friendly FX graph.
This tracer traces models into operators. That is,
the traced graph mostly contains call_function nodes and
has no call_module nodes. The call_module nodes
are problematic to the use of make_fx(...) in ONNX
exporter.
"""
@_beartype.beartype
def is_leaf_module(
self, module: torch.nn.Module, module_qualified_name: str
) -> bool:
# This returns False so that all sub-modules are considered non-leaf
# and therefore expanded into operators by
# torch.fx._symbolic_trace.Tracer.call_module.
return False
@_beartype.beartype
def to_bool(self, obj: "torch.fx.Proxy") -> bool:
# FIXME: This is a hack to trace through Python if-else blocks.
# It may generate incorrect ONNX graphs if the branch taken
# depends on the traced tensor's value.
return False
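def _example_module_expansion_tracer() -> None:
    """Illustrative sketch, not part of this commit: unlike the default FX
    tracer, which keeps torch.nn.ReLU as a call_module leaf, this tracer
    traces through it and leaves only call_function nodes."""

    class _Top(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.act = torch.nn.ReLU()

        def forward(self, x):
            return self.act(x)

    graph = ModuleExpansionTracer().trace(_Top())
    assert all(node.op != "call_module" for node in graph.nodes)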
@_beartype.beartype
def _move_placeholder_to_front(graph_module: torch.fx.GraphModule) -> None:
"""
This function moves all placeholder nodes to the front of the graph node list.
In torch.fx.Graph, placeholder is a special assignment node. If it's not
executed at the beginning, it could overwrite values computed by upstream
nodes.
"""
graph = graph_module.graph
placeholders = []
first_not_placeholder = None
for node in graph.nodes:
if node.op == "placeholder":
placeholders.append(node)
if first_not_placeholder is None and node.op != "placeholder":
first_not_placeholder = node
if first_not_placeholder is None:
return
for placeholder in placeholders:
first_not_placeholder.prepend(placeholder)
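def _example_move_placeholder_to_front() -> None:
    """Illustrative sketch, not part of this commit: a placeholder created
    after a computed node (as happens when a get_attr node is converted in
    place) is hoisted in front of it."""
    graph = torch.fx.Graph()
    x = graph.placeholder("x")
    y = graph.call_function(torch.relu, (x,))
    w = graph.placeholder("w")  # simulates a freshly converted get_attr node
    graph.output((y, w))
    graph_module = torch.fx.GraphModule(torch.nn.Module(), graph)
    _move_placeholder_to_front(graph_module)
    # Node order is now x, w, relu, output.
    assert [n.op for n in graph_module.graph.nodes][:2] == ["placeholder"] * 2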
@_beartype.beartype
def _replace_get_attr_with_placeholder(
graph_module: torch.fx.GraphModule,
) -> Tuple[torch.Tensor, ...]:
"""
Replace get_attr with placeholder.
The parameters and buffers accessed by the original get_attr are returned;
they are useful when creating random inputs for the modified graph_module.
"""
graph = graph_module.graph
replaced_attrs: List[torch.Tensor] = []
for node in graph.nodes:
if node.op == "get_attr":
replaced_attr: Optional[torch.Tensor] = None
# get_attr could retrieve either parameter or buffer, so
# we need to try both.
try:
replaced_attr = graph_module.get_parameter(node.target)
except AttributeError:
# It's possible that the model author used a buffer instead of a
# parameter to store trainable weights. In this case,
# 1. get_parameter will raise something like
#    AttributeError: `bias` is not an nn.Parameter.
# 2. get_buffer should work.
replaced_attr = graph_module.get_buffer(node.target)
# Reassign op type so that get_attr node becomes placeholder node.
node.op = "placeholder"
# The target name in placeholder must be a valid Python identifier.
# Thus, we replace, e.g., "module.submodule.weight" with
# "module_submodule_weight".
node.target = node.target.replace(".", "_")
# Default value is None. This is needed as long as the "graph_module"
# has optional inputs. Assume the original forward signature is
# def forward(self, x, y=None)
# and the replaced get_attr node has target "z". Then, the modified
# signature should be
# def forward(self, x, y=None, z=None)
# Without the following line, the signature will be
# def forward(self, x, y=None, z)
# , which is not valid Python code.
node.args = (None,)
replaced_attrs.append(replaced_attr)
return tuple(replaced_attrs)
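def _example_replace_get_attr_with_placeholder() -> None:
    """Illustrative sketch, not part of this commit: after module expansion,
    a Linear layer's weight and bias appear as get_attr nodes; this rewrite
    turns them into extra graph inputs that can be fed explicitly."""

    class _TinyLinear(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(2, 2)

        def forward(self, x):
            return self.linear(x)

    graph_module = _module_expansion_symbolic_trace(_TinyLinear())
    replaced_attrs = _replace_get_attr_with_placeholder(graph_module)
    _move_placeholder_to_front(graph_module)
    graph_module.recompile()
    # The weight and bias are now ordinary inputs of the graph module.
    graph_module(torch.randn(1, 2), *replaced_attrs)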
@_beartype.beartype
def _trace_into_fx_graph_via_fx_symbolic_trace(
module: torch.nn.Module,
*args,
# kwargs are the keyword arguments to call "module"; that is,
# module(*args, **kwargs) must run.
**kwargs,
) -> Tuple["torch.fx.GraphModule", Tuple[Any, ...]]:
signature = inspect.signature(module.forward)
# We hope the input kwargs will be mapped to bound.args after binding.
# If not, we will raise an error.
bound = signature.bind(*args, **kwargs)
bound.apply_defaults()
# After apply_defaults, all non keyword-only arguments are in bound.args.
# Because the code below does not support keyword-only arguments,
# bound.kwargs must be empty.
assert len(bound.kwargs) == 0, bound.kwargs
# Create inputs to call symbolic trace (torch.fx.symbolic_trace)
# Example content of concrete_args:
# concrete_args["x"] = torch.fx._symbolic_trace.PH
# concrete_args["b"] = 1
# where "x" and "b" are argument names in "signature".
concrete_args = {}
for param_name, param_value in bound.arguments.items():
if isinstance(param_value, torch.Tensor):
# param_value can be, e.g., a real tensor or a fake tensor.
# param_value is treated as substitutable tensor symbol (aka placeholder).
concrete_args[param_name] = torch.fx._symbolic_trace.PH
else:
concrete_args[param_name] = param_value
return (
_module_expansion_symbolic_trace(module, concrete_args=concrete_args),
bound.args,
)
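def _example_concrete_args() -> None:
    """Illustrative sketch, not part of this commit: how bound arguments map
    to concrete_args. Tensors become the PH placeholder sentinel; everything
    else is baked into the traced graph as a constant."""

    def _forward(x: torch.Tensor, flag: bool = True):
        return x * 2 if flag else x

    bound = inspect.signature(_forward).bind(torch.randn(2))
    bound.apply_defaults()
    concrete_args = {
        name: torch.fx._symbolic_trace.PH if isinstance(value, torch.Tensor) else value
        for name, value in bound.arguments.items()
    }
    assert concrete_args["flag"] is True  # "x" maps to the PH sentinel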
def _wrap_for_symbolic_trace(target: Callable) -> Tuple[Callable, Callable]:
"""This function wraps ```target`` for symbolic tracing.
This function wraps ```target``` so that its wrapper produces
torch.fx.Proxy in symbolic computation. The returned values are
the wrapper and then the original function. Per `_TORCH_METHODS_TO_PATCH`,
this function shall receive `torch.arange`, `torch.tensor`, etc. as inputs.
"""
@functools.wraps(target)
def wrapper(*args, **kwargs):
proxy = None
def check_has_proxy(v):
if isinstance(v, torch.fx.Proxy):
nonlocal proxy
proxy = v
torch.fx.node.map_aggregate(args, check_has_proxy)
torch.fx.node.map_aggregate(kwargs, check_has_proxy)
if proxy is not None:
return proxy.tracer.create_proxy("call_function", target, args, kwargs)
else:
return target(*args, **kwargs)
return wrapper, target
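def _example_wrap_for_symbolic_trace() -> None:
    """Illustrative sketch, not part of this commit: the wrapper's fallback
    path. With no torch.fx.Proxy among the arguments it simply calls the
    original function; during tracing, a Proxy argument instead records a
    call_function node via proxy.tracer.create_proxy."""
    wrapper, original = _wrap_for_symbolic_trace(torch.arange)
    assert torch.equal(wrapper(3), original(3))  # plain call, no tracing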
@_beartype.beartype
def _module_expansion_symbolic_trace(
root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None,
) -> torch.fx.GraphModule:
"""Trace a callable into FX graph.
When "root" is torch.nn.Module, calls to its submodule (type: torch.nn.Module) will be
expanded into operators (e.g., torch.matmul, torch.add, +, and -) to simplify graph
structure.
"""
# For functions doesn't support symbolic tracing, create wrappers
# which produce symbolic results during tracing.
patched_torch_methods = {
target_name: _wrap_for_symbolic_trace(getattr(torch, target_name))
for target_name in _TORCH_METHODS_TO_PATCH
}
# Set the symbolic-tracing friendly functions so that `tracer.trace` below
# can work.
for name, (wrapper, _) in patched_torch_methods.items():
setattr(torch, name, wrapper)
try:
# Set up a tracer.
tracer = ModuleExpansionTracer()
# Trace the model.
graph = tracer.trace(root, concrete_args)
name = (
root.__class__.__name__
if isinstance(root, torch.nn.Module)
else root.__name__
)
return torch.fx.GraphModule(tracer.root, graph, name)
finally:
# Revert the patches for symbolic tracing.
for name, (_, wrapped) in patched_torch_methods.items():
# `wrapped` is the original version of `torch.<name>`.
setattr(torch, name, wrapped)
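def _example_symbolic_trace_with_patches() -> None:
    """Illustrative sketch, not part of this commit: with the patches active,
    torch.arange(x.shape[0]) receives a torch.fx.Proxy and is recorded as a
    call_function node instead of being evaluated eagerly."""

    class _Net(torch.nn.Module):
        def forward(self, x):
            return x + torch.arange(x.shape[0])

    graph_module = _module_expansion_symbolic_trace(_Net())
    graph_module(torch.randn(3))  # runs like the original module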
@_beartype.beartype
def export_without_parameters_and_buffers(
module: torch.nn.Module,
*args,
decomposition_table: Optional[Dict[torch._ops.OpOverload, Callable]] = None,
use_binary_format: bool = True,
opset_version: int = _constants.ONNX_DEFAULT_OPSET,
op_level_debug: bool = False,
# kwargs are the keyword arguments to call "module"; that is,
# module(*args, **kwargs) must run.
**kwargs,
) -> Tuple[
Union[onnx.ModelProto, bytes],
torch.fx.GraphModule,
Tuple[Any, ...],
Tuple[Any, ...],
]:
graph_module, bound_args = _trace_into_fx_graph_via_fx_symbolic_trace(
module, *args, **kwargs
)
# Make sure all placeholder nodes are executed before get_attr nodes.
# Otherwise, inputs can interleave with initializers in the final ModelProto.graph.input.
# Basically, we want
# ModelProto.graph.input =
# [input_0, input_1, ..., input_n, weight_0, weight_1, ..., weight_m]
# and we don't want
# ModelProto.graph.input =
# [input_0, weight_0, input_1, weight_1, ...]
_move_placeholder_to_front(graph_module)
# To save memory, move get_attr nodes to inputs so that the generated model doesn't
# have weight tensors. "replaced_attrs" is the tuple of replaced weight tensors.
replaced_attrs = _replace_get_attr_with_placeholder(graph_module)
# Move all newly created placeholder nodes to the front of the graph.
_move_placeholder_to_front(graph_module)
# Finalize the graph editing.
graph_module.recompile()
return (
exporter._export(
graph_module,
(*bound_args, *replaced_attrs),
opset_version=opset_version,
decomposition_table=decomposition_table,
use_binary_format=use_binary_format,
op_level_debug=op_level_debug,
),
graph_module,
bound_args,
replaced_attrs,
)
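def _example_export_without_parameters_and_buffers() -> None:
    """Illustrative sketch, not part of this commit: an end-to-end call. The
    exported model carries no weight initializers; its inputs are the real
    inputs followed by the extracted weights, matching
    (*bound_args, *replaced_attrs). save_model_with_external_data from
    .serialization can re-attach the weights later."""

    class _TinyLinear(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(2, 2)

        def forward(self, x):
            return self.linear(x)

    exported, graph_module, bound_args, replaced_attrs = (
        export_without_parameters_and_buffers(_TinyLinear(), torch.randn(1, 2))
    )
    # `exported` is the ONNX model (serialized bytes under the default
    # use_binary_format=True); graph_module(*bound_args, *replaced_attrs)
    # reproduces the original module's output.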