diff --git a/torch/onnx/_internal/fx/__init__.py b/torch/onnx/_internal/fx/__init__.py index 57fbf56c528..a9f379e2fd2 100644 --- a/torch/onnx/_internal/fx/__init__.py +++ b/torch/onnx/_internal/fx/__init__.py @@ -1,10 +1,7 @@ from .context import FxToOnnxContext -from .exporter import ( - export, - export_after_normalizing_args_and_kwargs, - export_without_parameters_and_buffers, - save_model_with_external_data, -) +from .exporter import export, export_after_normalizing_args_and_kwargs +from .serialization import save_model_with_external_data +from .symbolic_exporter import export_without_parameters_and_buffers __all__ = [ diff --git a/torch/onnx/_internal/fx/exporter.py b/torch/onnx/_internal/fx/exporter.py index 1d18cb8ab07..e6193cdf501 100644 --- a/torch/onnx/_internal/fx/exporter.py +++ b/torch/onnx/_internal/fx/exporter.py @@ -1,11 +1,9 @@ from __future__ import annotations import copy -import functools import inspect import itertools import operator -import os import re import warnings from types import FunctionType @@ -75,117 +73,6 @@ onnxscript.OnnxFunction.__call__ = _diagnose_onnx_function( ) -class ModuleExpansionTracer(torch.fx._symbolic_trace.Tracer): - """Tracer to create ONNX-exporting friendly FX graph. - - This tracer traces models into operators. That is, - the traced graph mostly contains call_function nodes and - has no call_module nodes. The call_module nodes - are problematic to the use of make_fx(...) in ONNX - exporter. - """ - - @_beartype.beartype - def is_leaf_module( - self, module: torch.nn.Module, module_qualified_name: str - ) -> bool: - # This returns False so that all sub-modules are considered as not leaves - # and therefore expanded into operators in - # torch.fx._symbolic_trace.Tracer.call_module. - return False - - @_beartype.beartype - def to_bool(self, obj: "torch.fx.Proxy") -> bool: - # This is a hack to tracing through if-else Python blocks. - # It may generate incorrect ONNX graphs if the if-else block - return False - - -# Functions directly wrapped to produce torch.fx.Proxy so that symbolic -# data can flow through those functions. Python functions (e.g., `torch.arange`) -# not defined by pybind11 in C++ do not go though Python dispatcher, so -# they are not automatically patched by FX's Python dispatcher. -# The list below means `torch.arange`, `torch.tensor`, and so on will be -# patched. -_TORCH_METHODS_TO_PATCH: Tuple[str, ...] = ( - "arange", - "tensor", - "finfo", - "full", - "empty", -) - - -def _wrap_for_symbolic_trace(target: Callable) -> Tuple[Callable, Callable]: - """This function wraps ```target`` for symbolic tracing. - - This function wraps ```target``` so that its wrapper produces - torch.fx.Proxy in symbolic computation. The returned values are - the wrapper and then the original function. Per `_TORCH_METHODS_TO_PATCH`, - this function shall receive `torch.arange`, `torch.tensor`, etc. as inputs. - """ - - @functools.wraps(target) - def wrapper(*args, **kwargs): - proxy = None - - def check_has_proxy(v): - if isinstance(v, torch.fx.Proxy): - nonlocal proxy - proxy = v - - torch.fx.node.map_aggregate(args, check_has_proxy) - torch.fx.node.map_aggregate(kwargs, check_has_proxy) - - if proxy is not None: - return proxy.tracer.create_proxy("call_function", target, args, kwargs) - else: - return target(*args, **kwargs) - - return wrapper, target - - -@_beartype.beartype -def _module_expansion_symbolic_trace( - root: Union[torch.nn.Module, Callable[..., Any]], - concrete_args: Optional[Dict[str, Any]] = None, -) -> "torch.fx.GraphModule": - """Trace a callable into FX graph. - - When "root" is torch.nn.Module, calls to its submodule (type: torch.nn.Module) will be - expanded into operators (e.g., torch.matmul, torch.add, +, and -) to simplify graph - structure. - """ - # For functions doesn't support symbolic tracing, create wrappers - # which produce symbolic results during tracing. - patched_torch_methods = { - target_name: _wrap_for_symbolic_trace(getattr(torch, target_name)) - for target_name in _TORCH_METHODS_TO_PATCH - } - - # Set the symbolic-tracing friendly functions so that `tracer.trace` below - # can work. - for name, (wrapper, _) in patched_torch_methods.items(): - setattr(torch, name, wrapper) - - try: - # Set up a tracer. - tracer = ModuleExpansionTracer() - # Trace the model. - graph = tracer.trace(root, concrete_args) - name = ( - root.__class__.__name__ - if isinstance(root, torch.nn.Module) - else root.__name__ - ) - return torch.fx.GraphModule(tracer.root, graph, name) - finally: - # Revert the patches for symbolic tracing. - for name, (_, wrapped) in patched_torch_methods.items(): - # wrapped is the original version of `torch.name`. - setattr(torch, name, wrapped) - - def _retrieve_or_adapt_input_to_graph_set(fx_node_arg, fx_name_to_onnxscipt_value): """Map FX value to TorchScript value. @@ -780,309 +667,6 @@ def export_after_normalizing_args_and_kwargs( ) -@_beartype.beartype -def _move_placeholder_to_front(graph_module: "torch.fx.GraphModule") -> None: - """ - This function move all placeholder nodes to the front of the graph node list. - In torch.fx.Graph, placeholder is a special assignment node. If it's not - executed in the beginning, it could overwrite values computed by upstream - nodes. - """ - - graph = graph_module.graph - placeholders = [] - first_not_placeholder = None - for node in graph.nodes: - if node.op == "placeholder": - placeholders.append(node) - if first_not_placeholder is None and node.op != "placeholder": - first_not_placeholder = node - if first_not_placeholder is None: - return - for placeholder in placeholders: - first_not_placeholder.prepend(placeholder) - - -@_beartype.beartype -def _replace_get_attr_with_placeholder( - graph_module: "torch.fx.GraphModule", -) -> Tuple[torch.Tensor, ...]: - """ - Replace get_attr with placeholder. - The parameters and buffers accessed by the original get_attr are returned; - they are useful when creating random inputs for the modified graph_module. - """ - graph = graph_module.graph - replaced_attrs: List[torch.Tensor] = [] - for node in graph.nodes: - if node.op == "get_attr": - replaced_attr: Optional[torch.Tensor] = None - # get_attr could retrieve either parameter or buffer, so - # we need to try both. - try: - replaced_attr = graph_module.get_parameter(node.target) - except AttributeError: - # It's possible that model author use buffer instead of - # parameter to store trainable weights. In this case, - # 1. get_parameter will throw something like - # AttributeError: `bias` is not an nn.Parameter. - # 2. get_buffer should work. - replaced_attr = graph_module.get_buffer(node.target) - - # Reassign op type so that get_attr node becomes placeholder node. - node.op = "placeholder" - # The target name in placeholder must be a valid Python identifier. - # Thus, we replace, e.g., "module.submodule.weight" with - # "module_submodule_weight". - node.target = node.target.replace(".", "_") - # Default value is None. This is needed as long as the "graph_module" - # has optional inputs. Assume the original forward signature is - # def forward(self, x, y=None) - # and the replaced get_attr node has target "z". Then, the modified - # signature should be - # def forward(self, x, y=None, z=None) - # Without the following line, the signature will be - # def forward(self, x, y=None, z) - # , which is not valid Python code. - node.args = (None,) - - replaced_attrs.append(replaced_attr) - - return tuple(replaced_attrs) - - -@_beartype.beartype -def _trace_into_fx_graph_via_fx_symbolic_trace( - module: torch.nn.Module, - *args, - # kwargs are the keyword arguments to call "module"; that is, - # module(*args, **kwargs) must run. - **kwargs, -) -> Tuple["torch.fx.GraphModule", Tuple[Any, ...]]: - signature = inspect.signature(module.forward) - - # We hope the input kwargs will be mapped to bound.args after binding. - # If not, we will raise an error. - bound = signature.bind(*args, **kwargs) - bound.apply_defaults() - # After apply_defaults, all non keyword-only arguments are in bound.args. - # Because below code do not support keyword-word arguments, bound.kwargs - # must be empty. - assert len(bound.kwargs) == 0, bound.kwargs - - # Create inputs to call symbolic trace (torch.fx.symbolic_trace) - # Example content of concrete_args: - # concrete_args["x"] = torch.fx._symbolic_trace.PH - # concrete_args["b"] = 1 - # where "x" and "b" are argument names in "signature". - concrete_args = {} - for param_name, param_value in bound.arguments.items(): - if isinstance(param_value, torch.Tensor): - # param_value can be, e.g., a real tensor or a fake tensor. - # param_value is treated as substitutable tensor symbol (aka placeholder). - concrete_args[param_name] = torch.fx._symbolic_trace.PH - else: - concrete_args[param_name] = param_value - - return ( - _module_expansion_symbolic_trace(module, concrete_args=concrete_args), - bound.args, - ) - - -@_beartype.beartype -def export_without_parameters_and_buffers( - module: torch.nn.Module, - *args, - decomposition_table: Optional[Dict[torch._ops.OpOverload, Callable]] = None, - use_binary_format: bool = True, - opset_version: int = _constants.ONNX_DEFAULT_OPSET, - op_level_debug: bool = False, - # kwargs are the keyword arguments to call "module"; that is, - # module(*args, **kwargs) must run. - **kwargs, -) -> Tuple[ - Union["onnx.ModelProto", bytes], - "torch.fx.GraphModule", - Tuple[Any, ...], - Tuple[Any, ...], -]: - - graph_module, bound_args = _trace_into_fx_graph_via_fx_symbolic_trace( - module, *args, **kwargs - ) - - # Make sure all placeholder nodes are executed before get_attr nodes. - # Otherwise, inputs can interleave with initializers in the final ModeoProto.graph.input. - # Basically, we want - # ModeoProto.graph.input = - # [input_0, input_1, ..., input_n, weight_0, weight_1, ..., weight_m] - # and we don't want - # ModeoProto.graph.input = - # [input_0, weight_0, input_1, weight_1, ..., input_n, weight_0, weight_1, ..., weight_m] - _move_placeholder_to_front(graph_module) - # To save memory, move get_attr to input so that the generated model doesn't - # have weigh tensors. "replaced_attrs" are the list of replaced weight tensors. - replaced_attrs = _replace_get_attr_with_placeholder(graph_module) - # Move all newly created placeholder nodes to the front of the graph. - _move_placeholder_to_front(graph_module) - # Finalize the graph editing. - graph_module.recompile() - return ( - _export( - graph_module, - (*bound_args, *replaced_attrs), - opset_version=opset_version, - decomposition_table=decomposition_table, - use_binary_format=use_binary_format, - op_level_debug=op_level_debug, - ), - graph_module, - bound_args, - replaced_attrs, - ) - - -@_beartype.beartype -def _create_tensor_proto_with_external_data( - tensor: torch.Tensor, name: str, location: str, basepath: str -) -> "onnx.TensorProto": - """Create a TensorProto with external data from a PyTorch tensor. - The external data is saved to os.path.join(basepath, location). - - Args: - tensor: Tensor to be saved. - name: Name of the tensor (i.e., initializer name in ONNX graph). - location: Relative location of the external data file - (e.g., "/tmp/initializers/weight_0" when model is "/tmp/model_name.onnx"). - basepath: Base path of the external data file (e.g., "/tmp/external_data" while model must be in "/tmp"). - - - Reference for ONNX's external data format: - How to load? - https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L187 - How to save? - https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L43 - How to set ONNX fields? - https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L88 - """ - tensor_proto = onnx.TensorProto() - tensor_proto.name = name - tensor_proto.data_type = torch.onnx._type_utils._SCALAR_TYPE_TO_ONNX[ # type: ignore[assignment] - torch.onnx._type_utils._DTYPE_TO_SCALAR_TYPE[tensor.dtype] - ] - tensor_proto.dims.extend(tensor.shape) - tensor_proto.data_location = onnx.TensorProto.EXTERNAL - - # Settings for saving one tensor per file. - # Offset is zero because there is no other tensor in the same file. - key_value_pairs = { - "location": location, - "offset": 0, - "length": tensor.untyped_storage().nbytes(), - } - for k, v in key_value_pairs.items(): - entry = tensor_proto.external_data.add() - entry.key = k - entry.value = str(v) - - # Actual path to write content of tensor. - external_data_file_path = os.path.join(basepath, location) - if os.path.exists(external_data_file_path): - os.remove(external_data_file_path) - - # Create external data's folder if not exists. - external_data_dir_path = os.path.dirname(external_data_file_path) - if not os.path.exists(external_data_dir_path): - # if the demo_folder directory is not present - # then create it. - os.makedirs(external_data_dir_path) - - # Create a fresh file. - with open(external_data_file_path, "xb") as data_file: - # No need to call "seek" because offset is 0. - # data_file.seek(0) - # Write tensor content to the file. - data_file.write(tensor.numpy().tobytes()) - - return tensor_proto - - -@_beartype.beartype -def save_model_with_external_data( - basepath: str, - model_location: str, - initializer_location: str, - torch_load_paths: Tuple[str, ...], - onnx_model: "onnx.ModelProto", -) -> None: - """Load PyTorch tensors from files and add to "onnx_model" as external initializers. - - Output files: - ONNX model file path: - ONNX initializer folder: os.path.join(basepath, initializer_location) - - After running this function, you can do - ort_sess = onnxruntime.InferenceSession(os.path.join(basepath, model_location)) - to execute the model. - - Arguments: - basepath: Base path of the external data file (e.g., "/tmp/large-onnx-model"). - model_location: Relative location of the ONNX model file. - E.g., "model.onnx" so that the model file is saved to - "/tmp/large-onnx-model/model.onnx". - initializer_location: Relative location of the ONNX initializer folder. - E.g., "initializers" so that the initializers are saved to - "/tmp/large-onnx-model/initializers". - torch_load_paths: Files which containing serialized PyTorch tensors to be saved - as ONNX initializers. They are loaded by torch.load. - onnx_model: ONNX model to be saved with external initializers. - If an input name matches a tensor loaded from "torch_load_paths", - the tensor will be saved as that input's external initializer. - """ - onnx_model_with_initializers = onnx.ModelProto() - onnx_model_with_initializers.CopyFrom(onnx_model) - onnx_input_names = [input.name for input in onnx_model.graph.input] - - for path in torch_load_paths: - state_ditc = torch.load(path) - for name, tensor in state_ditc.items(): - # Basically, "transformer.attention.self.query.weight" is mapped - # to "transformer_attention_self_query_weight" for mimicking the - # name-modifying code in FX-to-ONNX exporter. - # See function _replace_get_attr_with_placeholder for details. - refined_name = name.replace(".", "_") - - # For each refined PyTorch tensor name loaded by torch.load, - # 1. Search its best match in ONNX model. E.g., the match of - # "transformer_attention_weight" could be "attention_weight". - # 2. Set "tensor" as the initializer of the matched ONNX input. - # E.g., "tensor" is stored as the initializer of "attention_weight". - # Step 1 is required because sometimes, tensor names are stored with prefix the dictionary - # loaded by torch.load. - for onnx_input_name in onnx_input_names: - if onnx_input_name.endswith(refined_name) or refined_name.endswith( - onnx_input_name - ): - # Find a match. Change refined_name to the matched ONNX input name, so that we - # create initializer with the right ONNX name. - refined_name = onnx_input_name - break - - relative_tensor_file_path = os.path.join(initializer_location, refined_name) - # Create one file per tensor. - # tensor_proto.raw_data is stored to external file at - # os.path.join(basepath, relative_tensor_file_path). - tensor_proto = _create_tensor_proto_with_external_data( - tensor, refined_name, relative_tensor_file_path, basepath - ) - # Add the tensor_proto to the ONNX model as an initializer with external data. - onnx_model_with_initializers.graph.initializer.append(tensor_proto) - - # model_location should be a pure file name such as "file_name.onnx", not "folder/file_name.onnx". - onnx.save(onnx_model_with_initializers, os.path.join(basepath, model_location)) - - @_beartype.beartype def _validate_op_between_ort_torch( node: torch.fx.Node, diff --git a/torch/onnx/_internal/fx/serialization.py b/torch/onnx/_internal/fx/serialization.py new file mode 100644 index 00000000000..75aba61edba --- /dev/null +++ b/torch/onnx/_internal/fx/serialization.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +import os +from typing import Tuple + +import onnx + +import torch +from torch.onnx._internal import _beartype + + +@_beartype.beartype +def _create_tensor_proto_with_external_data( + tensor: torch.Tensor, name: str, location: str, basepath: str +) -> "onnx.TensorProto": + """Create a TensorProto with external data from a PyTorch tensor. + The external data is saved to os.path.join(basepath, location). + + Args: + tensor: Tensor to be saved. + name: Name of the tensor (i.e., initializer name in ONNX graph). + location: Relative location of the external data file + (e.g., "/tmp/initializers/weight_0" when model is "/tmp/model_name.onnx"). + basepath: Base path of the external data file (e.g., "/tmp/external_data" while model must be in "/tmp"). + + + Reference for ONNX's external data format: + How to load? + https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L187 + How to save? + https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L43 + How to set ONNX fields? + https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L88 + """ + tensor_proto = onnx.TensorProto() + tensor_proto.name = name + tensor_proto.data_type = torch.onnx._type_utils._SCALAR_TYPE_TO_ONNX[ # type: ignore[assignment] + torch.onnx._type_utils._DTYPE_TO_SCALAR_TYPE[tensor.dtype] + ] + tensor_proto.dims.extend(tensor.shape) + tensor_proto.data_location = onnx.TensorProto.EXTERNAL + + # Settings for saving one tensor per file. + # Offset is zero because there is no other tensor in the same file. + key_value_pairs = { + "location": location, + "offset": 0, + "length": tensor.untyped_storage().nbytes(), + } + for k, v in key_value_pairs.items(): + entry = tensor_proto.external_data.add() + entry.key = k + entry.value = str(v) + + # Actual path to write content of tensor. + external_data_file_path = os.path.join(basepath, location) + if os.path.exists(external_data_file_path): + os.remove(external_data_file_path) + + # Create external data's folder if not exists. + external_data_dir_path = os.path.dirname(external_data_file_path) + if not os.path.exists(external_data_dir_path): + # if the demo_folder directory is not present + # then create it. + os.makedirs(external_data_dir_path) + + # Create a fresh file. + with open(external_data_file_path, "xb") as data_file: + # No need to call "seek" because offset is 0. + # data_file.seek(0) + # Write tensor content to the file. + data_file.write(tensor.numpy().tobytes()) + + return tensor_proto + + +@_beartype.beartype +def save_model_with_external_data( + basepath: str, + model_location: str, + initializer_location: str, + torch_load_paths: Tuple[str, ...], + onnx_model: onnx.ModelProto, +) -> None: + """Load PyTorch tensors from files and add to "onnx_model" as external initializers. + + Output files: + ONNX model file path: + ONNX initializer folder: os.path.join(basepath, initializer_location) + + After running this function, you can do + ort_sess = onnxruntime.InferenceSession(os.path.join(basepath, model_location)) + to execute the model. + + Arguments: + basepath: Base path of the external data file (e.g., "/tmp/large-onnx-model"). + model_location: Relative location of the ONNX model file. + E.g., "model.onnx" so that the model file is saved to + "/tmp/large-onnx-model/model.onnx". + initializer_location: Relative location of the ONNX initializer folder. + E.g., "initializers" so that the initializers are saved to + "/tmp/large-onnx-model/initializers". + torch_load_paths: Files which containing serialized PyTorch tensors to be saved + as ONNX initializers. They are loaded by torch.load. + onnx_model: ONNX model to be saved with external initializers. + If an input name matches a tensor loaded from "torch_load_paths", + the tensor will be saved as that input's external initializer. + """ + onnx_model_with_initializers = onnx.ModelProto() + onnx_model_with_initializers.CopyFrom(onnx_model) + onnx_input_names = [input.name for input in onnx_model.graph.input] + + for path in torch_load_paths: + state_ditc = torch.load(path) + for name, tensor in state_ditc.items(): + # Basically, "transformer.attention.self.query.weight" is mapped + # to "transformer_attention_self_query_weight" for mimicking the + # name-modifying code in FX-to-ONNX exporter. + # See function _replace_get_attr_with_placeholder for details. + refined_name = name.replace(".", "_") + + # For each refined PyTorch tensor name loaded by torch.load, + # 1. Search its best match in ONNX model. E.g., the match of + # "transformer_attention_weight" could be "attention_weight". + # 2. Set "tensor" as the initializer of the matched ONNX input. + # E.g., "tensor" is stored as the initializer of "attention_weight". + # Step 1 is required because sometimes, tensor names are stored with prefix the dictionary + # loaded by torch.load. + for onnx_input_name in onnx_input_names: + if onnx_input_name.endswith(refined_name) or refined_name.endswith( + onnx_input_name + ): + # Find a match. Change refined_name to the matched ONNX input name, so that we + # create initializer with the right ONNX name. + refined_name = onnx_input_name + break + + relative_tensor_file_path = os.path.join(initializer_location, refined_name) + # Create one file per tensor. + # tensor_proto.raw_data is stored to external file at + # os.path.join(basepath, relative_tensor_file_path). + tensor_proto = _create_tensor_proto_with_external_data( + tensor, refined_name, relative_tensor_file_path, basepath + ) + # Add the tensor_proto to the ONNX model as an initializer with external data. + onnx_model_with_initializers.graph.initializer.append(tensor_proto) + + # model_location should be a pure file name such as "file_name.onnx", not "folder/file_name.onnx". + onnx.save(onnx_model_with_initializers, os.path.join(basepath, model_location)) diff --git a/torch/onnx/_internal/fx/symbolic_exporter.py b/torch/onnx/_internal/fx/symbolic_exporter.py new file mode 100644 index 00000000000..a6e05253a53 --- /dev/null +++ b/torch/onnx/_internal/fx/symbolic_exporter.py @@ -0,0 +1,289 @@ +from __future__ import annotations + +import functools + +import inspect + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import onnx + +import torch +import torch.fx + +from torch.onnx import _constants +from torch.onnx._internal import _beartype +from torch.onnx._internal.fx import exporter + +# Functions directly wrapped to produce torch.fx.Proxy so that symbolic +# data can flow through those functions. Python functions (e.g., `torch.arange`) +# not defined by pybind11 in C++ do not go though Python dispatcher, so +# they are not automatically patched by FX's Python dispatcher. +# The list below means `torch.arange`, `torch.tensor`, and so on will be +# patched. +_TORCH_METHODS_TO_PATCH: Tuple[str, ...] = ( + "arange", + "tensor", + "finfo", + "full", + "empty", +) + + +class ModuleExpansionTracer(torch.fx._symbolic_trace.Tracer): + """Tracer to create ONNX-exporting friendly FX graph. + + This tracer traces models into operators. That is, + the traced graph mostly contains call_function nodes and + has no call_module nodes. The call_module nodes + are problematic to the use of make_fx(...) in ONNX + exporter. + """ + + @_beartype.beartype + def is_leaf_module( + self, module: torch.nn.Module, module_qualified_name: str + ) -> bool: + # This returns False so that all sub-modules are considered as not leaves + # and therefore expanded into operators in + # torch.fx._symbolic_trace.Tracer.call_module. + return False + + @_beartype.beartype + def to_bool(self, obj: "torch.fx.Proxy") -> bool: + # FIXME: This is a hack to tracing through if-else Python blocks. + # It may generate incorrect ONNX graphs if the if-else block + return False + + +@_beartype.beartype +def _move_placeholder_to_front(graph_module: torch.fx.GraphModule) -> None: + """ + This function move all placeholder nodes to the front of the graph node list. + In torch.fx.Graph, placeholder is a special assignment node. If it's not + executed in the beginning, it could overwrite values computed by upstream + nodes. + """ + + graph = graph_module.graph + placeholders = [] + first_not_placeholder = None + for node in graph.nodes: + if node.op == "placeholder": + placeholders.append(node) + if first_not_placeholder is None and node.op != "placeholder": + first_not_placeholder = node + if first_not_placeholder is None: + return + for placeholder in placeholders: + first_not_placeholder.prepend(placeholder) + + +@_beartype.beartype +def _replace_get_attr_with_placeholder( + graph_module: torch.fx.GraphModule, +) -> Tuple[torch.Tensor, ...]: + """ + Replace get_attr with placeholder. + The parameters and buffers accessed by the original get_attr are returned; + they are useful when creating random inputs for the modified graph_module. + """ + graph = graph_module.graph + replaced_attrs: List[torch.Tensor] = [] + for node in graph.nodes: + if node.op == "get_attr": + replaced_attr: Optional[torch.Tensor] = None + # get_attr could retrieve either parameter or buffer, so + # we need to try both. + try: + replaced_attr = graph_module.get_parameter(node.target) + except AttributeError: + # It's possible that model author use buffer instead of + # parameter to store trainable weights. In this case, + # 1. get_parameter will throw something like + # AttributeError: `bias` is not an nn.Parameter. + # 2. get_buffer should work. + replaced_attr = graph_module.get_buffer(node.target) + + # Reassign op type so that get_attr node becomes placeholder node. + node.op = "placeholder" + # The target name in placeholder must be a valid Python identifier. + # Thus, we replace, e.g., "module.submodule.weight" with + # "module_submodule_weight". + node.target = node.target.replace(".", "_") + # Default value is None. This is needed as long as the "graph_module" + # has optional inputs. Assume the original forward signature is + # def forward(self, x, y=None) + # and the replaced get_attr node has target "z". Then, the modified + # signature should be + # def forward(self, x, y=None, z=None) + # Without the following line, the signature will be + # def forward(self, x, y=None, z) + # , which is not valid Python code. + node.args = (None,) + + replaced_attrs.append(replaced_attr) + + return tuple(replaced_attrs) + + +@_beartype.beartype +def _trace_into_fx_graph_via_fx_symbolic_trace( + module: torch.nn.Module, + *args, + # kwargs are the keyword arguments to call "module"; that is, + # module(*args, **kwargs) must run. + **kwargs, +) -> Tuple["torch.fx.GraphModule", Tuple[Any, ...]]: + signature = inspect.signature(module.forward) + + # We hope the input kwargs will be mapped to bound.args after binding. + # If not, we will raise an error. + bound = signature.bind(*args, **kwargs) + bound.apply_defaults() + # After apply_defaults, all non keyword-only arguments are in bound.args. + # Because below code do not support keyword-word arguments, bound.kwargs + # must be empty. + assert len(bound.kwargs) == 0, bound.kwargs + + # Create inputs to call symbolic trace (torch.fx.symbolic_trace) + # Example content of concrete_args: + # concrete_args["x"] = torch.fx._symbolic_trace.PH + # concrete_args["b"] = 1 + # where "x" and "b" are argument names in "signature". + concrete_args = {} + for param_name, param_value in bound.arguments.items(): + if isinstance(param_value, torch.Tensor): + # param_value can be, e.g., a real tensor or a fake tensor. + # param_value is treated as substitutable tensor symbol (aka placeholder). + concrete_args[param_name] = torch.fx._symbolic_trace.PH + else: + concrete_args[param_name] = param_value + + return ( + _module_expansion_symbolic_trace(module, concrete_args=concrete_args), + bound.args, + ) + + +def _wrap_for_symbolic_trace(target: Callable) -> Tuple[Callable, Callable]: + """This function wraps ```target`` for symbolic tracing. + + This function wraps ```target``` so that its wrapper produces + torch.fx.Proxy in symbolic computation. The returned values are + the wrapper and then the original function. Per `_TORCH_METHODS_TO_PATCH`, + this function shall receive `torch.arange`, `torch.tensor`, etc. as inputs. + """ + + @functools.wraps(target) + def wrapper(*args, **kwargs): + proxy = None + + def check_has_proxy(v): + if isinstance(v, torch.fx.Proxy): + nonlocal proxy + proxy = v + + torch.fx.node.map_aggregate(args, check_has_proxy) + torch.fx.node.map_aggregate(kwargs, check_has_proxy) + + if proxy is not None: + return proxy.tracer.create_proxy("call_function", target, args, kwargs) + else: + return target(*args, **kwargs) + + return wrapper, target + + +@_beartype.beartype +def _module_expansion_symbolic_trace( + root: Union[torch.nn.Module, Callable[..., Any]], + concrete_args: Optional[Dict[str, Any]] = None, +) -> torch.fx.GraphModule: + """Trace a callable into FX graph. + + When "root" is torch.nn.Module, calls to its submodule (type: torch.nn.Module) will be + expanded into operators (e.g., torch.matmul, torch.add, +, and -) to simplify graph + structure. + """ + # For functions doesn't support symbolic tracing, create wrappers + # which produce symbolic results during tracing. + patched_torch_methods = { + target_name: _wrap_for_symbolic_trace(getattr(torch, target_name)) + for target_name in _TORCH_METHODS_TO_PATCH + } + + # Set the symbolic-tracing friendly functions so that `tracer.trace` below + # can work. + for name, (wrapper, _) in patched_torch_methods.items(): + setattr(torch, name, wrapper) + + try: + # Set up a tracer. + tracer = ModuleExpansionTracer() + # Trace the model. + graph = tracer.trace(root, concrete_args) + name = ( + root.__class__.__name__ + if isinstance(root, torch.nn.Module) + else root.__name__ + ) + return torch.fx.GraphModule(tracer.root, graph, name) + finally: + # Revert the patches for symbolic tracing. + for name, (_, wrapped) in patched_torch_methods.items(): + # wrapped is the original version of `torch.name`. + setattr(torch, name, wrapped) + + +@_beartype.beartype +def export_without_parameters_and_buffers( + module: torch.nn.Module, + *args, + decomposition_table: Optional[Dict[torch._ops.OpOverload, Callable]] = None, + use_binary_format: bool = True, + opset_version: int = _constants.ONNX_DEFAULT_OPSET, + op_level_debug: bool = False, + # kwargs are the keyword arguments to call "module"; that is, + # module(*args, **kwargs) must run. + **kwargs, +) -> Tuple[ + Union[onnx.ModelProto, bytes], + torch.fx.GraphModule, + Tuple[Any, ...], + Tuple[Any, ...], +]: + + graph_module, bound_args = _trace_into_fx_graph_via_fx_symbolic_trace( + module, *args, **kwargs + ) + + # Make sure all placeholder nodes are executed before get_attr nodes. + # Otherwise, inputs can interleave with initializers in the final ModeoProto.graph.input. + # Basically, we want + # ModeoProto.graph.input = + # [input_0, input_1, ..., input_n, weight_0, weight_1, ..., weight_m] + # and we don't want + # ModeoProto.graph.input = + # [input_0, weight_0, input_1, weight_1, ..., input_n, weight_0, weight_1, ..., weight_m] + _move_placeholder_to_front(graph_module) + # To save memory, move get_attr to input so that the generated model doesn't + # have weigh tensors. "replaced_attrs" are the list of replaced weight tensors. + replaced_attrs = _replace_get_attr_with_placeholder(graph_module) + # Move all newly created placeholder nodes to the front of the graph. + _move_placeholder_to_front(graph_module) + # Finalize the graph editing. + graph_module.recompile() + return ( + exporter._export( + graph_module, + (*bound_args, *replaced_attrs), + opset_version=opset_version, + decomposition_table=decomposition_table, + use_binary_format=use_binary_format, + op_level_debug=op_level_debug, + ), + graph_module, + bound_args, + replaced_attrs, + )