[ONNX] Move symbolic export to separate file (#95650)

Move things around in preparation for refactoring
the code structure.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95650
Approved by: https://github.com/titaiwangms
BowenBao 2023-03-07 10:31:37 -08:00 committed by PyTorch MergeBot
parent d06729746c
commit 82dba844bb
4 changed files with 441 additions and 422 deletions

torch/onnx/_internal/fx/__init__.py View File

@@ -1,10 +1,7 @@
from .context import FxToOnnxContext
from .exporter import (
export,
export_after_normalizing_args_and_kwargs,
export_without_parameters_and_buffers,
save_model_with_external_data,
)
from .exporter import export, export_after_normalizing_args_and_kwargs
from .serialization import save_model_with_external_data
from .symbolic_exporter import export_without_parameters_and_buffers
__all__ = [

torch/onnx/_internal/fx/exporter.py View File

@@ -1,11 +1,9 @@
from __future__ import annotations
import copy
import functools
import inspect
import itertools
import operator
import os
import re
import warnings
from types import FunctionType
@@ -75,117 +73,6 @@ onnxscript.OnnxFunction.__call__ = _diagnose_onnx_function(
)
class ModuleExpansionTracer(torch.fx._symbolic_trace.Tracer):
"""Tracer to create ONNX-exporting friendly FX graph.
This tracer traces models into operators. That is,
the traced graph mostly contains call_function nodes and
has no call_module nodes. The call_module nodes
are problematic to the use of make_fx(...) in ONNX
exporter.
"""
@_beartype.beartype
def is_leaf_module(
self, module: torch.nn.Module, module_qualified_name: str
) -> bool:
# This returns False so that no sub-module is considered a leaf;
# all sub-modules are therefore expanded into operators by
# torch.fx._symbolic_trace.Tracer.call_module.
return False
@_beartype.beartype
def to_bool(self, obj: "torch.fx.Proxy") -> bool:
# This is a hack to trace through Python if-else blocks.
# It may generate incorrect ONNX graphs if the branch taken by the
# if-else block depends on actual tensor values, since tracing always
# follows the False branch here.
return False
# Functions directly wrapped to produce torch.fx.Proxy so that symbolic
# data can flow through those functions. Python functions (e.g., `torch.arange`)
# not defined by pybind11 in C++ do not go through the Python dispatcher, so
# they are not automatically patched by FX's Python dispatcher.
# The list below means `torch.arange`, `torch.tensor`, and so on will be
# patched.
_TORCH_METHODS_TO_PATCH: Tuple[str, ...] = (
"arange",
"tensor",
"finfo",
"full",
"empty",
)
def _wrap_for_symbolic_trace(target: Callable) -> Tuple[Callable, Callable]:
"""This function wraps ```target`` for symbolic tracing.
This function wraps ```target``` so that its wrapper produces
torch.fx.Proxy in symbolic computation. The returned values are
the wrapper and then the original function. Per `_TORCH_METHODS_TO_PATCH`,
this function shall receive `torch.arange`, `torch.tensor`, etc. as inputs.
"""
@functools.wraps(target)
def wrapper(*args, **kwargs):
proxy = None
def check_has_proxy(v):
if isinstance(v, torch.fx.Proxy):
nonlocal proxy
proxy = v
torch.fx.node.map_aggregate(args, check_has_proxy)
torch.fx.node.map_aggregate(kwargs, check_has_proxy)
if proxy is not None:
return proxy.tracer.create_proxy("call_function", target, args, kwargs)
else:
return target(*args, **kwargs)
return wrapper, target
@_beartype.beartype
def _module_expansion_symbolic_trace(
root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None,
) -> "torch.fx.GraphModule":
"""Trace a callable into FX graph.
When "root" is torch.nn.Module, calls to its submodule (type: torch.nn.Module) will be
expanded into operators (e.g., torch.matmul, torch.add, +, and -) to simplify graph
structure.
"""
# For functions doesn't support symbolic tracing, create wrappers
# which produce symbolic results during tracing.
patched_torch_methods = {
target_name: _wrap_for_symbolic_trace(getattr(torch, target_name))
for target_name in _TORCH_METHODS_TO_PATCH
}
# Set the symbolic-tracing friendly functions so that `tracer.trace` below
# can work.
for name, (wrapper, _) in patched_torch_methods.items():
setattr(torch, name, wrapper)
try:
# Set up a tracer.
tracer = ModuleExpansionTracer()
# Trace the model.
graph = tracer.trace(root, concrete_args)
name = (
root.__class__.__name__
if isinstance(root, torch.nn.Module)
else root.__name__
)
return torch.fx.GraphModule(tracer.root, graph, name)
finally:
# Revert the patches for symbolic tracing.
for name, (_, wrapped) in patched_torch_methods.items():
# wrapped is the original version of `torch.name`.
setattr(torch, name, wrapped)
def _retrieve_or_adapt_input_to_graph_set(fx_node_arg, fx_name_to_onnxscipt_value):
"""Map FX value to TorchScript value.
@@ -780,309 +667,6 @@ def export_after_normalizing_args_and_kwargs(
)
@_beartype.beartype
def _move_placeholder_to_front(graph_module: "torch.fx.GraphModule") -> None:
"""
This function moves all placeholder nodes to the front of the graph's node list.
In torch.fx.Graph, a placeholder is a special assignment node. If it is not
executed at the beginning, it could overwrite values computed by upstream
nodes.
"""
graph = graph_module.graph
placeholders = []
first_not_placeholder = None
for node in graph.nodes:
if node.op == "placeholder":
placeholders.append(node)
if first_not_placeholder is None and node.op != "placeholder":
first_not_placeholder = node
if first_not_placeholder is None:
return
for placeholder in placeholders:
first_not_placeholder.prepend(placeholder)
@_beartype.beartype
def _replace_get_attr_with_placeholder(
graph_module: "torch.fx.GraphModule",
) -> Tuple[torch.Tensor, ...]:
"""
Replace get_attr with placeholder.
The parameters and buffers accessed by the original get_attr are returned;
they are useful when creating random inputs for the modified graph_module.
"""
graph = graph_module.graph
replaced_attrs: List[torch.Tensor] = []
for node in graph.nodes:
if node.op == "get_attr":
replaced_attr: Optional[torch.Tensor] = None
# get_attr could retrieve either parameter or buffer, so
# we need to try both.
try:
replaced_attr = graph_module.get_parameter(node.target)
except AttributeError:
# It's possible that the model author used a buffer instead of a
# parameter to store trainable weights. In that case,
# 1. get_parameter raises something like
# AttributeError: `bias` is not an nn.Parameter.
# 2. get_buffer should work.
replaced_attr = graph_module.get_buffer(node.target)
# Reassign op type so that get_attr node becomes placeholder node.
node.op = "placeholder"
# The target name in placeholder must be a valid Python identifier.
# Thus, we replace, e.g., "module.submodule.weight" with
# "module_submodule_weight".
node.target = node.target.replace(".", "_")
# Default value is None. This is needed as long as the "graph_module"
# has optional inputs. Assume the original forward signature is
# def forward(self, x, y=None)
# and the replaced get_attr node has target "z". Then, the modified
# signature should be
# def forward(self, x, y=None, z=None)
# Without the following line, the signature will be
# def forward(self, x, y=None, z)
# , which is not valid Python code.
node.args = (None,)
replaced_attrs.append(replaced_attr)
return tuple(replaced_attrs)
@_beartype.beartype
def _trace_into_fx_graph_via_fx_symbolic_trace(
module: torch.nn.Module,
*args,
# kwargs are the keyword arguments to call "module"; that is,
# module(*args, **kwargs) must run.
**kwargs,
) -> Tuple["torch.fx.GraphModule", Tuple[Any, ...]]:
signature = inspect.signature(module.forward)
# We expect the input kwargs to be mapped into bound.args after binding;
# if they cannot be bound, signature.bind raises an error.
bound = signature.bind(*args, **kwargs)
bound.apply_defaults()
# After apply_defaults, all non keyword-only arguments are in bound.args.
# Because the code below does not support keyword-only arguments, bound.kwargs
# must be empty.
assert len(bound.kwargs) == 0, bound.kwargs
# Create inputs to call symbolic trace (torch.fx.symbolic_trace)
# Example content of concrete_args:
# concrete_args["x"] = torch.fx._symbolic_trace.PH
# concrete_args["b"] = 1
# where "x" and "b" are argument names in "signature".
concrete_args = {}
for param_name, param_value in bound.arguments.items():
if isinstance(param_value, torch.Tensor):
# param_value can be, e.g., a real tensor or a fake tensor.
# param_value is treated as substitutable tensor symbol (aka placeholder).
concrete_args[param_name] = torch.fx._symbolic_trace.PH
else:
concrete_args[param_name] = param_value
return (
_module_expansion_symbolic_trace(module, concrete_args=concrete_args),
bound.args,
)
@_beartype.beartype
def export_without_parameters_and_buffers(
module: torch.nn.Module,
*args,
decomposition_table: Optional[Dict[torch._ops.OpOverload, Callable]] = None,
use_binary_format: bool = True,
opset_version: int = _constants.ONNX_DEFAULT_OPSET,
op_level_debug: bool = False,
# kwargs are the keyword arguments to call "module"; that is,
# module(*args, **kwargs) must run.
**kwargs,
) -> Tuple[
Union["onnx.ModelProto", bytes],
"torch.fx.GraphModule",
Tuple[Any, ...],
Tuple[Any, ...],
]:
graph_module, bound_args = _trace_into_fx_graph_via_fx_symbolic_trace(
module, *args, **kwargs
)
# Make sure all placeholder nodes are executed before get_attr nodes.
# Otherwise, inputs can interleave with initializers in the final ModelProto.graph.input.
# Basically, we want
# ModelProto.graph.input =
# [input_0, input_1, ..., input_n, weight_0, weight_1, ..., weight_m]
# and we don't want
# ModelProto.graph.input =
# [input_0, weight_0, input_1, weight_1, ..., input_n, weight_m]
_move_placeholder_to_front(graph_module)
# To save memory, move get_attr to input so that the generated model doesn't
# embed weight tensors. "replaced_attrs" is the list of replaced weight tensors.
replaced_attrs = _replace_get_attr_with_placeholder(graph_module)
# Move all newly created placeholder nodes to the front of the graph.
_move_placeholder_to_front(graph_module)
# Finalize the graph editing.
graph_module.recompile()
return (
_export(
graph_module,
(*bound_args, *replaced_attrs),
opset_version=opset_version,
decomposition_table=decomposition_table,
use_binary_format=use_binary_format,
op_level_debug=op_level_debug,
),
graph_module,
bound_args,
replaced_attrs,
)
@_beartype.beartype
def _create_tensor_proto_with_external_data(
tensor: torch.Tensor, name: str, location: str, basepath: str
) -> "onnx.TensorProto":
"""Create a TensorProto with external data from a PyTorch tensor.
The external data is saved to os.path.join(basepath, location).
Args:
tensor: Tensor to be saved.
name: Name of the tensor (i.e., initializer name in ONNX graph).
location: Relative location of the external data file
(e.g., "/tmp/initializers/weight_0" when model is "/tmp/model_name.onnx").
basepath: Base path of the external data file (e.g., "/tmp/external_data" while model must be in "/tmp").
Reference for ONNX's external data format:
How to load?
https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L187
How to save?
https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L43
How to set ONNX fields?
https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L88
"""
tensor_proto = onnx.TensorProto()
tensor_proto.name = name
tensor_proto.data_type = torch.onnx._type_utils._SCALAR_TYPE_TO_ONNX[ # type: ignore[assignment]
torch.onnx._type_utils._DTYPE_TO_SCALAR_TYPE[tensor.dtype]
]
tensor_proto.dims.extend(tensor.shape)
tensor_proto.data_location = onnx.TensorProto.EXTERNAL
# Settings for saving one tensor per file.
# Offset is zero because there is no other tensor in the same file.
key_value_pairs = {
"location": location,
"offset": 0,
"length": tensor.untyped_storage().nbytes(),
}
for k, v in key_value_pairs.items():
entry = tensor_proto.external_data.add()
entry.key = k
entry.value = str(v)
# Actual path to write content of tensor.
external_data_file_path = os.path.join(basepath, location)
if os.path.exists(external_data_file_path):
os.remove(external_data_file_path)
# Create the external data folder if it does not exist.
external_data_dir_path = os.path.dirname(external_data_file_path)
if not os.path.exists(external_data_dir_path):
# Create the directory since it is not present.
os.makedirs(external_data_dir_path)
# Create a fresh file.
with open(external_data_file_path, "xb") as data_file:
# No need to call "seek" because offset is 0.
# data_file.seek(0)
# Write tensor content to the file.
data_file.write(tensor.numpy().tobytes())
return tensor_proto
@_beartype.beartype
def save_model_with_external_data(
basepath: str,
model_location: str,
initializer_location: str,
torch_load_paths: Tuple[str, ...],
onnx_model: "onnx.ModelProto",
) -> None:
"""Load PyTorch tensors from files and add to "onnx_model" as external initializers.
Output files:
ONNX model file path: os.path.join(basepath, model_location)
ONNX initializer folder: os.path.join(basepath, initializer_location)
After running this function, you can do
ort_sess = onnxruntime.InferenceSession(os.path.join(basepath, model_location))
to execute the model.
Arguments:
basepath: Base path of the external data file (e.g., "/tmp/large-onnx-model").
model_location: Relative location of the ONNX model file.
E.g., "model.onnx" so that the model file is saved to
"/tmp/large-onnx-model/model.onnx".
initializer_location: Relative location of the ONNX initializer folder.
E.g., "initializers" so that the initializers are saved to
"/tmp/large-onnx-model/initializers".
torch_load_paths: Files containing serialized PyTorch tensors to be saved
as ONNX initializers. They are loaded by torch.load.
onnx_model: ONNX model to be saved with external initializers.
If an input name matches a tensor loaded from "torch_load_paths",
the tensor will be saved as that input's external initializer.
"""
onnx_model_with_initializers = onnx.ModelProto()
onnx_model_with_initializers.CopyFrom(onnx_model)
onnx_input_names = [input.name for input in onnx_model.graph.input]
for path in torch_load_paths:
state_dict = torch.load(path)
for name, tensor in state_dict.items():
# Basically, "transformer.attention.self.query.weight" is mapped
# to "transformer_attention_self_query_weight" for mimicking the
# name-modifying code in FX-to-ONNX exporter.
# See function _replace_get_attr_with_placeholder for details.
refined_name = name.replace(".", "_")
# For each refined PyTorch tensor name loaded by torch.load,
# 1. Search its best match in ONNX model. E.g., the match of
# "transformer_attention_weight" could be "attention_weight".
# 2. Set "tensor" as the initializer of the matched ONNX input.
# E.g., "tensor" is stored as the initializer of "attention_weight".
# Step 1 is required because tensor names are sometimes stored with a prefix
# in the dictionary loaded by torch.load.
for onnx_input_name in onnx_input_names:
if onnx_input_name.endswith(refined_name) or refined_name.endswith(
onnx_input_name
):
# Found a match. Change refined_name to the matched ONNX input name, so that we
# create the initializer with the right ONNX name.
refined_name = onnx_input_name
break
relative_tensor_file_path = os.path.join(initializer_location, refined_name)
# Create one file per tensor.
# tensor_proto.raw_data is stored to external file at
# os.path.join(basepath, relative_tensor_file_path).
tensor_proto = _create_tensor_proto_with_external_data(
tensor, refined_name, relative_tensor_file_path, basepath
)
# Add the tensor_proto to the ONNX model as an initializer with external data.
onnx_model_with_initializers.graph.initializer.append(tensor_proto)
# model_location should be a pure file name such as "file_name.onnx", not "folder/file_name.onnx".
onnx.save(onnx_model_with_initializers, os.path.join(basepath, model_location))
@_beartype.beartype
def _validate_op_between_ort_torch(
node: torch.fx.Node,

torch/onnx/_internal/fx/serialization.py View File

@@ -0,0 +1,149 @@
from __future__ import annotations
import os
from typing import Tuple
import onnx
import torch
from torch.onnx._internal import _beartype
@_beartype.beartype
def _create_tensor_proto_with_external_data(
tensor: torch.Tensor, name: str, location: str, basepath: str
) -> "onnx.TensorProto":
"""Create a TensorProto with external data from a PyTorch tensor.
The external data is saved to os.path.join(basepath, location).
Args:
tensor: Tensor to be saved.
name: Name of the tensor (i.e., initializer name in ONNX graph).
location: Relative location of the external data file
(e.g., "/tmp/initializers/weight_0" when model is "/tmp/model_name.onnx").
basepath: Base path of the external data file (e.g., "/tmp/external_data" while model must be in "/tmp").
Reference for ONNX's external data format:
How to load?
https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L187
How to save?
https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L43
How to set ONNX fields?
https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L88
"""
tensor_proto = onnx.TensorProto()
tensor_proto.name = name
tensor_proto.data_type = torch.onnx._type_utils._SCALAR_TYPE_TO_ONNX[ # type: ignore[assignment]
torch.onnx._type_utils._DTYPE_TO_SCALAR_TYPE[tensor.dtype]
]
tensor_proto.dims.extend(tensor.shape)
tensor_proto.data_location = onnx.TensorProto.EXTERNAL
# Settings for saving one tensor per file.
# Offset is zero because there is no other tensor in the same file.
key_value_pairs = {
"location": location,
"offset": 0,
"length": tensor.untyped_storage().nbytes(),
}
for k, v in key_value_pairs.items():
entry = tensor_proto.external_data.add()
entry.key = k
entry.value = str(v)
# Actual path to write content of tensor.
external_data_file_path = os.path.join(basepath, location)
if os.path.exists(external_data_file_path):
os.remove(external_data_file_path)
# Create the external data folder if it does not exist.
external_data_dir_path = os.path.dirname(external_data_file_path)
if not os.path.exists(external_data_dir_path):
# Create the directory since it is not present.
os.makedirs(external_data_dir_path)
# Create a fresh file.
with open(external_data_file_path, "xb") as data_file:
# No need to call "seek" because offset is 0.
# data_file.seek(0)
# Write tensor content to the file.
data_file.write(tensor.numpy().tobytes())
return tensor_proto
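As a quick sanity check of the on-disk format, here is a minimal round-trip sketch. It assumes the module path this PR introduces (torch.onnx._internal.fx.serialization) and relies on onnx.numpy_helper.to_array resolving the proto's "location" entry relative to base_dir; the /tmp path and tensor are illustrative only.

import os

import onnx.numpy_helper
import torch

from torch.onnx._internal.fx.serialization import (
    _create_tensor_proto_with_external_data,
)

basepath = "/tmp/external_data_demo"  # illustrative path
os.makedirs(basepath, exist_ok=True)
tensor = torch.arange(6, dtype=torch.float32)
# Writes raw bytes to /tmp/external_data_demo/w0.bin and returns a
# TensorProto whose data_location is EXTERNAL.
proto = _create_tensor_proto_with_external_data(
    tensor, name="w0", location="w0.bin", basepath=basepath
)
# onnx materializes external data relative to base_dir.
roundtrip = onnx.numpy_helper.to_array(proto, base_dir=basepath)
assert roundtrip.tolist() == tensor.tolist()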
@_beartype.beartype
def save_model_with_external_data(
basepath: str,
model_location: str,
initializer_location: str,
torch_load_paths: Tuple[str, ...],
onnx_model: onnx.ModelProto,
) -> None:
"""Load PyTorch tensors from files and add to "onnx_model" as external initializers.
Output files:
ONNX model file path: os.path.join(basepath, model_location)
ONNX initializer folder: os.path.join(basepath, initializer_location)
After running this function, you can do
ort_sess = onnxruntime.InferenceSession(os.path.join(basepath, model_location))
to execute the model.
Arguments:
basepath: Base path of the external data file (e.g., "/tmp/large-onnx-model").
model_location: Relative location of the ONNX model file.
E.g., "model.onnx" so that the model file is saved to
"/tmp/large-onnx-model/model.onnx".
initializer_location: Relative location of the ONNX initializer folder.
E.g., "initializers" so that the initializers are saved to
"/tmp/large-onnx-model/initializers".
torch_load_paths: Files containing serialized PyTorch tensors to be saved
as ONNX initializers. They are loaded by torch.load.
onnx_model: ONNX model to be saved with external initializers.
If an input name matches a tensor loaded from "torch_load_paths",
the tensor will be saved as that input's external initializer.
"""
onnx_model_with_initializers = onnx.ModelProto()
onnx_model_with_initializers.CopyFrom(onnx_model)
onnx_input_names = [input.name for input in onnx_model.graph.input]
for path in torch_load_paths:
state_dict = torch.load(path)
for name, tensor in state_dict.items():
# Basically, "transformer.attention.self.query.weight" is mapped
# to "transformer_attention_self_query_weight" for mimicking the
# name-modifying code in FX-to-ONNX exporter.
# See function _replace_get_attr_with_placeholder for details.
refined_name = name.replace(".", "_")
# For each refined PyTorch tensor name loaded by torch.load,
# 1. Search its best match in ONNX model. E.g., the match of
# "transformer_attention_weight" could be "attention_weight".
# 2. Set "tensor" as the initializer of the matched ONNX input.
# E.g., "tensor" is stored as the initializer of "attention_weight".
# Step 1 is required because tensor names are sometimes stored with a prefix
# in the dictionary loaded by torch.load.
for onnx_input_name in onnx_input_names:
if onnx_input_name.endswith(refined_name) or refined_name.endswith(
onnx_input_name
):
# Found a match. Change refined_name to the matched ONNX input name, so that we
# create the initializer with the right ONNX name.
refined_name = onnx_input_name
break
relative_tensor_file_path = os.path.join(initializer_location, refined_name)
# Create one file per tensor.
# tensor_proto.raw_data is stored to external file at
# os.path.join(basepath, relative_tensor_file_path).
tensor_proto = _create_tensor_proto_with_external_data(
tensor, refined_name, relative_tensor_file_path, basepath
)
# Add the tensor_proto to the ONNX model as an initializer with external data.
onnx_model_with_initializers.graph.initializer.append(tensor_proto)
# model_location should be a pure file name such as "file_name.onnx", not "folder/file_name.onnx".
onnx.save(onnx_model_with_initializers, os.path.join(basepath, model_location))
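Typical usage pairs this helper with a model exported so that its weights are graph inputs. A hedged sketch: every path and the checkpoint below are hypothetical, and it assumes "model_no_weights.onnx" has graph inputs whose names match the checkpoint's tensor names.

import onnx
import torch

from torch.onnx._internal.fx.serialization import save_model_with_external_data

# Hypothetical inputs: an exported model whose parameters are graph
# inputs, plus a checkpoint holding the matching tensors.
onnx_model = onnx.load("/tmp/large-onnx-model/model_no_weights.onnx")
save_model_with_external_data(
    basepath="/tmp/large-onnx-model",
    model_location="model.onnx",
    initializer_location="initializers",
    torch_load_paths=("/tmp/large-onnx-model/ckpt.pt",),
    onnx_model=onnx_model,
)
# One file per tensor now lives under /tmp/large-onnx-model/initializers/,
# and model.onnx references them as external data.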

torch/onnx/_internal/fx/symbolic_exporter.py View File

@@ -0,0 +1,289 @@
from __future__ import annotations
import functools
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import onnx
import torch
import torch.fx
from torch.onnx import _constants
from torch.onnx._internal import _beartype
from torch.onnx._internal.fx import exporter
# Functions directly wrapped to produce torch.fx.Proxy so that symbolic
# data can flow through those functions. Python functions (e.g., `torch.arange`)
# not defined by pybind11 in C++ do not go through the Python dispatcher, so
# they are not automatically patched by FX's Python dispatcher.
# The list below means `torch.arange`, `torch.tensor`, and so on will be
# patched.
_TORCH_METHODS_TO_PATCH: Tuple[str, ...] = (
"arange",
"tensor",
"finfo",
"full",
"empty",
)
class ModuleExpansionTracer(torch.fx._symbolic_trace.Tracer):
"""Tracer to create ONNX-exporting friendly FX graph.
This tracer traces models into operators. That is,
the traced graph mostly contains call_function nodes and
has no call_module nodes. The call_module nodes
are problematic to the use of make_fx(...) in ONNX
exporter.
"""
@_beartype.beartype
def is_leaf_module(
self, module: torch.nn.Module, module_qualified_name: str
) -> bool:
# This returns False so that no sub-module is considered a leaf;
# all sub-modules are therefore expanded into operators by
# torch.fx._symbolic_trace.Tracer.call_module.
return False
@_beartype.beartype
def to_bool(self, obj: "torch.fx.Proxy") -> bool:
# FIXME: This is a hack to trace through Python if-else blocks.
# It may generate incorrect ONNX graphs if the branch taken by the
# if-else block depends on actual tensor values, since tracing always
# follows the False branch here.
return False
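To make is_leaf_module's effect concrete, a small sketch (the Net module is hypothetical) shows that returning False expands submodule calls into call_function nodes:

import torch

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))

tracer = ModuleExpansionTracer()
graph = tracer.trace(Net())
# nn.Linear is traced through: the graph holds get_attr nodes for the
# weights and call_function nodes for linear/relu, but no call_module.
assert all(node.op != "call_module" for node in graph.nodes)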
@_beartype.beartype
def _move_placeholder_to_front(graph_module: torch.fx.GraphModule) -> None:
"""
This function moves all placeholder nodes to the front of the graph's node list.
In torch.fx.Graph, a placeholder is a special assignment node. If it is not
executed at the beginning, it could overwrite values computed by upstream
nodes.
"""
graph = graph_module.graph
placeholders = []
first_not_placeholder = None
for node in graph.nodes:
if node.op == "placeholder":
placeholders.append(node)
if first_not_placeholder is None and node.op != "placeholder":
first_not_placeholder = node
if first_not_placeholder is None:
return
for placeholder in placeholders:
first_not_placeholder.prepend(placeholder)
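Node.prepend is what preserves the placeholders' relative order: each placeholder is moved, in original order, to just before the first non-placeholder, so each later one lands closer to it. A minimal check on a hypothetical two-input function, where the pass is a no-op:

import torch

def f(x, y):
    return x + y

gm = torch.fx.symbolic_trace(f)
_move_placeholder_to_front(gm)  # placeholders already lead; order is kept
print([node.op for node in gm.graph.nodes])
# ['placeholder', 'placeholder', 'call_function', 'output']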
@_beartype.beartype
def _replace_get_attr_with_placeholder(
graph_module: torch.fx.GraphModule,
) -> Tuple[torch.Tensor, ...]:
"""
Replace get_attr with placeholder.
The parameters and buffers accessed by the original get_attr are returned;
they are useful when creating random inputs for the modified graph_module.
"""
graph = graph_module.graph
replaced_attrs: List[torch.Tensor] = []
for node in graph.nodes:
if node.op == "get_attr":
replaced_attr: Optional[torch.Tensor] = None
# get_attr could retrieve either parameter or buffer, so
# we need to try both.
try:
replaced_attr = graph_module.get_parameter(node.target)
except AttributeError:
# It's possible that the model author used a buffer instead of a
# parameter to store trainable weights. In that case,
# 1. get_parameter raises something like
# AttributeError: `bias` is not an nn.Parameter.
# 2. get_buffer should work.
replaced_attr = graph_module.get_buffer(node.target)
# Reassign op type so that get_attr node becomes placeholder node.
node.op = "placeholder"
# The target name in placeholder must be a valid Python identifier.
# Thus, we replace, e.g., "module.submodule.weight" with
# "module_submodule_weight".
node.target = node.target.replace(".", "_")
# Default value is None. This is needed as long as the "graph_module"
# has optional inputs. Assume the original forward signature is
# def forward(self, x, y=None)
# and the replaced get_attr node has target "z". Then, the modified
# signature should be
# def forward(self, x, y=None, z=None)
# Without the following line, the signature will be
# def forward(self, x, y=None, z)
# , which is not valid Python code.
node.args = (None,)
replaced_attrs.append(replaced_attr)
return tuple(replaced_attrs)
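A toy before/after (the Bias module is hypothetical) showing the get_attr node turning into a placeholder, after which the parameter must be fed as an explicit input:

import torch

class Bias(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bias = torch.nn.Parameter(torch.randn(3))

    def forward(self, x):
        return x + self.bias

gm = torch.fx.symbolic_trace(Bias())
print([n.op for n in gm.graph.nodes])
# ['placeholder', 'get_attr', 'call_function', 'output']
replaced = _replace_get_attr_with_placeholder(gm)
_move_placeholder_to_front(gm)
gm.recompile()
# The get_attr node is now a placeholder named "bias", and `replaced`
# holds the original parameter so callers can pass it in.
print(gm(torch.zeros(3), replaced[0]))  # zeros + bias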
@_beartype.beartype
def _trace_into_fx_graph_via_fx_symbolic_trace(
module: torch.nn.Module,
*args,
# kwargs are the keyword arguments to call "module"; that is,
# module(*args, **kwargs) must run.
**kwargs,
) -> Tuple["torch.fx.GraphModule", Tuple[Any, ...]]:
signature = inspect.signature(module.forward)
# We expect the input kwargs to be mapped into bound.args after binding;
# if they cannot be bound, signature.bind raises an error.
bound = signature.bind(*args, **kwargs)
bound.apply_defaults()
# After apply_defaults, all non keyword-only arguments are in bound.args.
# Because the code below does not support keyword-only arguments, bound.kwargs
# must be empty.
assert len(bound.kwargs) == 0, bound.kwargs
# Create inputs to call symbolic trace (torch.fx.symbolic_trace)
# Example content of concrete_args:
# concrete_args["x"] = torch.fx._symbolic_trace.PH
# concrete_args["b"] = 1
# where "x" and "b" are argument names in "signature".
concrete_args = {}
for param_name, param_value in bound.arguments.items():
if isinstance(param_value, torch.Tensor):
# param_value can be, e.g., a real tensor or a fake tensor.
# param_value is treated as substitutable tensor symbol (aka placeholder).
concrete_args[param_name] = torch.fx._symbolic_trace.PH
else:
concrete_args[param_name] = param_value
return (
_module_expansion_symbolic_trace(module, concrete_args=concrete_args),
bound.args,
)
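The concrete_args construction can be illustrated standalone; this sketch uses a plain function in place of module.forward:

import inspect

import torch

def forward(x: torch.Tensor, flag: bool = True):
    return x * 2 if flag else x

signature = inspect.signature(forward)
bound = signature.bind(torch.ones(2), flag=False)
bound.apply_defaults()
assert len(bound.kwargs) == 0
# Tensors become the PH sentinel (a substitutable placeholder);
# non-tensors are baked into the trace as concrete values.
concrete_args = {
    name: torch.fx._symbolic_trace.PH if isinstance(value, torch.Tensor) else value
    for name, value in bound.arguments.items()
}
print(concrete_args)  # {'x': <PH sentinel>, 'flag': False}
print(bound.args)     # (tensor([1., 1.]), False)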
def _wrap_for_symbolic_trace(target: Callable) -> Tuple[Callable, Callable]:
"""This function wraps ```target`` for symbolic tracing.
This function wraps ```target``` so that its wrapper produces
torch.fx.Proxy in symbolic computation. The returned values are
the wrapper and then the original function. Per `_TORCH_METHODS_TO_PATCH`,
this function shall receive `torch.arange`, `torch.tensor`, etc. as inputs.
"""
@functools.wraps(target)
def wrapper(*args, **kwargs):
proxy = None
def check_has_proxy(v):
if isinstance(v, torch.fx.Proxy):
nonlocal proxy
proxy = v
torch.fx.node.map_aggregate(args, check_has_proxy)
torch.fx.node.map_aggregate(kwargs, check_has_proxy)
if proxy is not None:
return proxy.tracer.create_proxy("call_function", target, args, kwargs)
else:
return target(*args, **kwargs)
return wrapper, target
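For illustration, patching a single op by hand exercises both branches of the wrapper: without a Proxy the real op runs eagerly, and with a Proxy among the arguments a call_function node is recorded instead. A self-contained sketch that restores the patch afterwards:

import torch

wrapper, original = _wrap_for_symbolic_trace(torch.full)
# Eager path: no Proxy among the arguments, so the real op runs.
print(wrapper((2,), 1.0))  # tensor([1., 1.])

torch.full = wrapper
try:
    def f(x):
        # x.sum() is a Proxy during tracing, so the wrapper records a
        # call_function node for torch.full instead of executing it.
        return torch.full((2,), x.sum())

    gm = torch.fx.symbolic_trace(f)
    print(gm.graph)  # contains a call_function node targeting torch.full
finally:
    torch.full = original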
@_beartype.beartype
def _module_expansion_symbolic_trace(
root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None,
) -> torch.fx.GraphModule:
"""Trace a callable into FX graph.
When "root" is torch.nn.Module, calls to its submodule (type: torch.nn.Module) will be
expanded into operators (e.g., torch.matmul, torch.add, +, and -) to simplify graph
structure.
"""
# For functions doesn't support symbolic tracing, create wrappers
# which produce symbolic results during tracing.
patched_torch_methods = {
target_name: _wrap_for_symbolic_trace(getattr(torch, target_name))
for target_name in _TORCH_METHODS_TO_PATCH
}
# Set the symbolic-tracing friendly functions so that `tracer.trace` below
# can work.
for name, (wrapper, _) in patched_torch_methods.items():
setattr(torch, name, wrapper)
try:
# Set up a tracer.
tracer = ModuleExpansionTracer()
# Trace the model.
graph = tracer.trace(root, concrete_args)
name = (
root.__class__.__name__
if isinstance(root, torch.nn.Module)
else root.__name__
)
return torch.fx.GraphModule(tracer.root, graph, name)
finally:
# Revert the patches for symbolic tracing.
for name, (_, wrapped) in patched_torch_methods.items():
# wrapped is the original version of `torch.name`.
setattr(torch, name, wrapped)
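Usage mirrors torch.fx.symbolic_trace; a sketch with a hypothetical two-layer module:

import torch

class MLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(8, 8)
        self.fc2 = torch.nn.Linear(8, 1)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

gm = _module_expansion_symbolic_trace(MLP())
# Both Linear submodules are expanded; no call_module nodes remain.
assert all(node.op != "call_module" for node in gm.graph.nodes)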
@_beartype.beartype
def export_without_parameters_and_buffers(
module: torch.nn.Module,
*args,
decomposition_table: Optional[Dict[torch._ops.OpOverload, Callable]] = None,
use_binary_format: bool = True,
opset_version: int = _constants.ONNX_DEFAULT_OPSET,
op_level_debug: bool = False,
# kwargs are the keyword arguments to call "module"; that is,
# module(*args, **kwargs) must run.
**kwargs,
) -> Tuple[
Union[onnx.ModelProto, bytes],
torch.fx.GraphModule,
Tuple[Any, ...],
Tuple[Any, ...],
]:
graph_module, bound_args = _trace_into_fx_graph_via_fx_symbolic_trace(
module, *args, **kwargs
)
# Make sure all placeholder nodes are executed before get_attr nodes.
# Otherwise, inputs can interleave with initializers in the final ModelProto.graph.input.
# Basically, we want
# ModelProto.graph.input =
# [input_0, input_1, ..., input_n, weight_0, weight_1, ..., weight_m]
# and we don't want
# ModelProto.graph.input =
# [input_0, weight_0, input_1, weight_1, ..., input_n, weight_m]
_move_placeholder_to_front(graph_module)
# To save memory, move get_attr to input so that the generated model doesn't
# embed weight tensors. "replaced_attrs" is the list of replaced weight tensors.
replaced_attrs = _replace_get_attr_with_placeholder(graph_module)
# Move all newly created placeholder nodes to the front of the graph.
_move_placeholder_to_front(graph_module)
# Finalize the graph editing.
graph_module.recompile()
return (
exporter._export(
graph_module,
(*bound_args, *replaced_attrs),
opset_version=opset_version,
decomposition_table=decomposition_table,
use_binary_format=use_binary_format,
op_level_debug=op_level_debug,
),
graph_module,
bound_args,
replaced_attrs,
)
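Putting it together, a hedged end-to-end sketch of the new entry point; TinyModel is hypothetical, the import path is the one this PR introduces, and running it requires the FX exporter's dependencies (onnx, onnxscript):

import torch

from torch.onnx._internal.fx.symbolic_exporter import (
    export_without_parameters_and_buffers,
)

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

export_output, graph_module, bound_args, replaced_attrs = (
    export_without_parameters_and_buffers(TinyModel(), torch.randn(1, 4))
)
# replaced_attrs holds the tensors stripped from the module (fc.weight
# and fc.bias); they correspond to extra placeholder inputs, so the
# exported graph itself carries no weights.
print(len(bound_args), len(replaced_attrs))  # 1 2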