pytorch/torch/onnx/_internal/fx/serialization.py
Thiago Crepaldi 32a67e42c4 Introduce FXGraphExtractor into torch.onnx.dynamo_export (#99940)
The current API architecture can be seen as 3 independent exporters as shown below. The public API `dynamo_export()` defaults to one of the 3 variants and the other 2 must be used by instantiating private classes: ![image](https://user-images.githubusercontent.com/5469809/231567368-ec899718-b7c1-4e59-b6a8-383142df245a.png)

This PR refactors the API in a way that `dynamo_export` is the only way to use the ONNX exporter. It defaults to an FX tracer based on ``torch.export``, but an internal-only idiom allows switching the FX tracer (aka `FXGraphExtractor` interface), as shown below:

![image](https://user-images.githubusercontent.com/5469809/231567495-3936362d-06de-4cfc-b752-6c2060701c08.png)

Summary of changes:

* Unifies all exporter variants under a single `dynamo_export` API
  * `ResolvedExportOptions` was expanded to allow `fx_tracer: FXGraphExtractor` to be specified, selecting which FX graph extractor to use, according to the design proposal
  * As a consequence, `torch.onnx._internal.exporter.Exporter` does not have to *internally* specialize for each type of FX API that the exporter might use. This leads to a single `Exporter` with many `FX graph extractors`
  * Before in red, after in green: ![image](https://user-images.githubusercontent.com/5469809/232633531-4c67449b-4863-474d-9e18-78fc1d31b1bd.png)
* Input processing was moved from `Exporter` subclasses to `FXGraphExtractor` subclasses, where it is actually consumed
  * `Exporter` is a [data]class that holds export options, model and input data in a single cohesive object. Specializing it means creating different exporters instead of having one exporter capable of exporting models through different options.
  * `Exporter` doesn't consume the `model_args` that caused it to specialize
* Improved the circular dependency story.
  * https://github.com/pytorch/pytorch/pull/99070 moves `import torch.onnx` to after all dynamo subcomponents, preventing `torch.onnx` from having circular dependencies when `torch.XXXX` is imported during initialization
  * There are other points we need to improve in subsequent PRs. APIs are organized in a way that makes it easy to "import too much"
* Refactored `decomposition_table` as an internal-only `ResolvedExportOptions` property.
  * Similar to input processing, this helper is not actually consumed at the `Exporter` layer. This PR moves it to the layer in which it is used
* Demoted `Exporter.model_signature` to a simple standalone helper
  * There is no need to have this as an exporter method; this is a standard `inspect.signature` usage without any state
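The single-exporter, pluggable-extractor layout above can be sketched in plain Python. This is a simplified illustration under stated assumptions, not the code from this PR. Only the names `Exporter`, `ResolvedExportOptions`, `FXGraphExtractor`, `fx_tracer`, `generate_fx` and `model_signature` come from the description. `DynamoStyleTracer`, `SymbolicStyleTracer` and the string "graphs" are hypothetical stand-ins for real FX tracing. The sketch shows one `Exporter` dataclass whose behavior varies only through `options.fx_tracer`, plus a stateless `model_signature` helper built on `inspect.signature`.

```python
import abc
import dataclasses
import inspect
from typing import Any, Callable, Dict, Tuple


class FXGraphExtractor(abc.ABC):
    """Tracer interface: turns a model plus example inputs into a graph."""

    @abc.abstractmethod
    def generate_fx(
        self, model: Callable[..., Any], model_args: Tuple[Any, ...], model_kwargs: Dict[str, Any]
    ) -> str:
        """Return a graph (a string in this sketch; an FX GraphModule upstream)."""


class DynamoStyleTracer(FXGraphExtractor):
    # Hypothetical default tracer; input processing would live here, where it is consumed.
    def generate_fx(self, model, model_args, model_kwargs):
        return f"dynamo-graph({model_signature(model)}, nargs={len(model_args)})"


class SymbolicStyleTracer(FXGraphExtractor):
    # Hypothetical alternative tracer, selected only through the options object.
    def generate_fx(self, model, model_args, model_kwargs):
        return f"symbolic-graph({model_signature(model)})"


@dataclasses.dataclass
class ResolvedExportOptions:
    fx_tracer: FXGraphExtractor = dataclasses.field(default_factory=DynamoStyleTracer)


@dataclasses.dataclass
class Exporter:
    """A single exporter: options, model and inputs in one object, no subclasses."""

    options: ResolvedExportOptions
    model: Callable[..., Any]
    model_args: Tuple[Any, ...]
    model_kwargs: Dict[str, Any]

    def export(self) -> str:
        graph = self.options.fx_tracer.generate_fx(self.model, self.model_args, self.model_kwargs)
        return f"onnx({graph})"


def model_signature(model: Callable[..., Any]) -> inspect.Signature:
    """Stateless helper: signature of model.forward if present, else of model itself."""
    return inspect.signature(getattr(model, "forward", model))


def add(x, y):
    return x + y


print(Exporter(ResolvedExportOptions(), add, (1, 2), {}).export())
# onnx(dynamo-graph((x, y), nargs=2))
```

Because the tracer sits behind an options field, a new tracing strategy is added as a new `FXGraphExtractor` subclass instead of a new `Exporter` subclass.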

Possible next steps are:
* Decouple `passes` and `dispatching` from the cluttered `export_fx_to_onnx`
* Further integration with http://github.com/pytorch/pytorch/pull/98421/ into `FXGraphExtractor` public API + helper for unit testing
  * Some passes change input processing in ways that are not captured by the proposed input adapter

**COPILOT SUMMARY**
### <samp>🤖 Generated by Copilot at bdaba31</samp>

### Summary
📝🚀🔧

<!--
1.  📝 - This emoji represents the formatting and documentation changes, such as adding an empty line, updating the `__all__` list, and improving the type annotations and docstrings.
2.  🚀 - This emoji represents the new features and enhancements, such as adding the `DynamoExport` class, supporting custom export options, and flattening HuggingFace model outputs.
3.  🔧 - This emoji represents the refactoring and restructuring changes, such as using the FX graph representation, the `io_adapter` module, and the simplified FX symbolic tracer, and renaming and reorganizing some modules and classes.
-->
This pull request refactors the ONNX exporter code to use the FX graph representation and the new `io_adapter` module for input and output adaptation. It also adds support for custom export options and flattening HuggingFace model outputs in the ONNX test framework. It updates the ONNX dynamo exporter API tests and adds a new module `torch/onnx/_internal/fx/dynamo_graph_extractor.py` for exporting FX models to ONNX with dynamo support. It fixes some type annotations, imports, and formatting issues in the ONNX exporter code.

> _The ONNX exporter got a new look_
> _With FX graph and dynamo hook_
> _It uses `io_adapter`_
> _And custom options matter_
> _For HuggingFace models and `model_signature` book_

### Walkthrough
*  Move the `fx` submodule from `torch/onnx/_internal` to `torch/onnx/_internal/fx`, and rename some of its modules ( [link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-c8fa56eefd7f98fb4f9739d57df57f02ede77e28528133736010a6d06651ebcbL21-R26), [link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-0795f54fd1f38cfbf2c4a863a4efc9f40f2ea020a2b1612605c361b8d8d35862L25-R26), [link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-3eef404cb9d85216c050be153c33255ebce1170a77d8b9b17be79bcfb238c9c4L5-R15), [link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-4da17ba9e1a187bfacb65a70d6ff15f6c2a60480be8e20fc452d8984a279cd0aL3-R30))
*  Add a new module `torch/onnx/_internal/fx/dynamo_graph_extractor.py` that defines a `DynamoExport` class for generating FX graphs using the `torch._dynamo.export` API ([link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-078d7b8d0e4050e650fc3c15dc97a0564852191ac7b7bdc069d0b3959c5ee39aR1-R77))
*  Add a new module `torch/onnx/_internal/fx/io_adapter.py` that defines the input and output adapter classes and steps for the ONNX exporter, and a helper function to wrap models with output adapters ([link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-0795f54fd1f38cfbf2c4a863a4efc9f40f2ea020a2b1612605c361b8d8d35862L159-R192), [link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-4da17ba9e1a187bfacb65a70d6ff15f6c2a60480be8e20fc452d8984a279cd0aL3-R30), [link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-4da17ba9e1a187bfacb65a70d6ff15f6c2a60480be8e20fc452d8984a279cd0aR72-R176), [link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-4da17ba9e1a187bfacb65a70d6ff15f6c2a60480be8e20fc452d8984a279cd0aL237-R478))
*  Update the `ResolvedExportOptions` class in `torch/onnx/_internal/exporter.py` to inherit from the `ExportOptions` class, and to set the `fx_tracer` and `decomposition_table` attributes based on the `dynamo_graph_extractor` and `function_dispatcher` modules ([link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-0795f54fd1f38cfbf2c4a863a4efc9f40f2ea020a2b1612605c361b8d8d35862L81-R99), [link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-0795f54fd1f38cfbf2c4a863a4efc9f40f2ea020a2b1612605c361b8d8d35862R117-R126))
*  Update the `Exporter` class in `torch/onnx/_internal/exporter.py` to remove the `export` method and add a new abstract `generate_fx` method, and to use the `fx_tracer` attribute to generate and export the FX graph ([link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-0795f54fd1f38cfbf2c4a863a4efc9f40f2ea020a2b1612605c361b8d8d35862L413-R475), [link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-0795f54fd1f38cfbf2c4a863a4efc9f40f2ea020a2b1612605c361b8d8d35862L422-R486))
*  Rename the `FXSymbolicTraceExporter` class in `torch/onnx/_internal/fx/fx_symbolic_graph_extractor.py` to `FXSymbolicTracer`, make it inherit from `exporter.FXGraphExtractor`, and implement the `generate_fx` method ([link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-3eef404cb9d85216c050be153c33255ebce1170a77d8b9b17be79bcfb238c9c4L128-R175), [link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-3eef404cb9d85216c050be153c33255ebce1170a77d8b9b17be79bcfb238c9c4L157-R219))
*  Rename the `export_fx_to_onnx` method of the `FXSymbolicTracer` class to `_export_fx_to_onnx` and move it to the `exporter.FXGraphExtractor` class ([link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-3eef404cb9d85216c050be153c33255ebce1170a77d8b9b17be79bcfb238c9c4L193-R234))
*  Update the `dynamo_export` function in `torch/onnx/_internal/exporter.py` to accept and return `ResolvedExportOptions` and `Exporter` objects, respectively ([link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-0795f54fd1f38cfbf2c4a863a4efc9f40f2ea020a2b1612605c361b8d8d35862L536-R606))
*  Update the `run_test_with_fx_to_onnx_exporter_and_onnx_runtime` function in `test/onnx/onnx_test_common.py` to add a new parameter `export_options` for passing custom export options to the `torch.onnx.dynamo_export` function ([link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-1b38383dc1a0228a835d83bb7c4ba2d0c1bcd41297be5c6336572c525846166eR176), [link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-1b38383dc1a0228a835d83bb7c4ba2d0c1bcd41297be5c6336572c525846166eL216-R222))
*  Update the `test_log_sigmoid` and `_test_large_scale_exporter` tests in `test/onnx/test_fx_to_onnx_with_onnxruntime.py` to use the updated `run_test_with_fx_to_onnx_exporter_and_onnx_runtime` function and the `torch.onnx.dynamo_export` function ([link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-c8fa56eefd7f98fb4f9739d57df57f02ede77e28528133736010a6d06651ebcbL297-R301), [link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-c8fa56eefd7f98fb4f9739d57df57f02ede77e28528133736010a6d06651ebcbL682-R686), [link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-c8fa56eefd7f98fb4f9739d57df57f02ede77e28528133736010a6d06651ebcbL696-R716), [link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-c8fa56eefd7f98fb4f9739d57df57f02ede77e28528133736010a6d06651ebcbL721-R730))
*  Update the `test_raise_on_invalid_save_argument_type` test in `test/onnx/dynamo/test_exporter_api.py` to use the `io_adapter.InputAdapter` and `io_adapter.OutputAdapter` classes instead of the `exporter.InputAdapter` and `exporter.OutputAdapter` classes ([link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-4545f0c15c73ebe90a875e9bee6c5ca4b6b92fb1ed0ec5560d1568e0f6339d02L139-R139))
*  Move the `model_signature` property from the `Exporter` class in `torch/onnx/_internal/exporter.py` to a standalone function in `torch/onnx/utils.py`, and update the references to it ([link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-0795f54fd1f38cfbf2c4a863a4efc9f40f2ea020a2b1612605c361b8d8d35862L432-R505), [link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-3eef404cb9d85216c050be153c33255ebce1170a77d8b9b17be79bcfb238c9c4L157-R219), [link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-849a5778e2dcf7f36587967273cee0bf20642e35bf4c79405111ea3417c3fb3cL54-R75))
*  Move the `UnsatisfiedDependencyError` class from the `Exporter` class in `torch/onnx/_internal/exporter.py` to the top level of the module, and update the references to it ([link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-0795f54fd1f38cfbf2c4a863a4efc9f40f2ea020a2b1612605c361b8d8d35862L442-R512))
*  Rename the `_create_onnx_friendly_decomposition_table` function and the `_ONNX_FRIENDLY_DECOMPOSITION_TABLE` dictionary in `torch/onnx/_internal/fx/function_dispatcher.py` to `_create_default_onnx_decomposition_table` and `_DEFAULT_ONNX_EXPORTER_DECOMPOSITION_TABLE`, respectively, and update the references to them ([link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-549890bc593f917c4e62c4c43077340e4774c0abdf31657ced8450fdfbed3b3eL213-R219), [link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-549890bc593f917c4e62c4c43077340e4774c0abdf31657ced8450fdfbed3b3eL231-R239))
*  Update the imports in `torch/onnx/_internal/fx/function_dispatcher.py` to use the `torch._ops` and `torch._decomp` modules instead of the `torch.ops` and `torch.decomp` modules, and to use aliases for accessing the `onnxscript.function_libs.torch_aten.ops` and `torch._ops` modules ([link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-549890bc593f917c4e62c4c43077340e4774c0abdf31657ced8450fdfbed3b3eL11-R16), [link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-549890bc593f917c4e62c4c43077340e4774c0abdf31657ced8450fdfbed3b3eL35-R156), [link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-549890bc593f917c4e62c4c43077340e4774c0abdf31657ced8450fdfbed3b3eL160-R166), [link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-549890bc593f917c4e62c4c43077340e4774c0abdf31657ced8450fdfbed3b3eL173-R182), [link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-549890bc593f917c4e62c4c43077340e4774c0abdf31657ced8450fdfbed3b3eL189-R194), [link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-549890bc593f917c4e62c4c43077340e4774c0abdf31657ced8450fdfbed3b3eL201-R204), [link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-549890bc593f917c4e62c4c43077340e4774c0abdf31657ced8450fdfbed3b3eL231-R239))
*  Update the `ExportOutput` class in `torch/onnx/_internal/exporter.py` to use the `InputAdapter` and `OutputAdapter` classes from `io_adapter` instead of the ones defined in the same module ([link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-0795f54fd1f38cfbf2c4a863a4efc9f40f2ea020a2b1612605c361b8d8d35862L275-R199))
*  Update the type annotations in `torch/onnx/_internal/fx/serialization.py` and `torch/onnx/_internal/exporter.py` to fix some inconsistencies ([link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-0c7a4333620a22a5c3e5315e30272b59fb7a11b393cb42f8255070bedeb02738L15-R15), [link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-0c7a4333620a22a5c3e5315e30272b59fb7a11b393cb42f8255070bedeb02738L83-R83), [link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-0795f54fd1f38cfbf2c4a863a4efc9f40f2ea020a2b1612605c361b8d8d35862L11-R11), [link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-0795f54fd1f38cfbf2c4a863a4efc9f40f2ea020a2b1612605c361b8d8d35862R18))
*  Remove an unused import of `inspect` from `torch/onnx/_internal/exporter.py` ([link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-0795f54fd1f38cfbf2c4a863a4efc9f40f2ea020a2b1612605c361b8d8d35862L5))
*  Remove an unused import of `torch._dynamo` from `torch/onnx/_internal/fx/passes/shape_inference.py` ([link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-d38827b1f79525963c39e5c480240cd81f4edcaf8b3bd374a1c6ee2fdb28b334L7))
*  Add a comment to `torch/onnx/_internal/fx/passes/shape_inference.py` to explain why the import of `torch._dynamo` is done inside the `_run` method of the `ShapeInferenceWithFakeTensor` class ([link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-d38827b1f79525963c39e5c480240cd81f4edcaf8b3bd374a1c6ee2fdb28b334R32-R35))
*  Fix a typo in the docstring of the `_module_expansion_symbolic_trace` function in `torch/onnx/_internal/fx/fx_symbolic_graph_extractor.py` ([link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-3eef404cb9d85216c050be153c33255ebce1170a77d8b9b17be79bcfb238c9c4L96-R98))
*  Add an empty line to `torch/onnx/__init__.py` for formatting purposes ([link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-c3c8c09b65c1235ca4494633c6a0aab2761a11a7653ddaf9f874bbcd91e15553R12))
*  Delete the `torch/onnx/_internal/fx/__init__.py` file ([link](https://github.com/pytorch/pytorch/pull/99940/files?diff=unified&w=0#diff-a39fa3741f027bb9717388fc922d1e846fbd43d44f2c5fbee4e8d2188a7edb85))
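The step-based input/output adaptation mentioned in the walkthrough can be sketched as follows. The class and step names here are hypothetical. They are only loosely modeled on the description of `io_adapter`, which covers adapters that apply a sequence of steps and the flattening of nested HuggingFace-style outputs. They are not that module's actual API. The sketch assumes the model signature has no `*args`/`**kwargs` parameters.

```python
import inspect
from typing import Any, Callable, Dict, List, Sequence, Tuple

Args = Tuple[Any, ...]
Kwargs = Dict[str, Any]


class BindKwargsIntoArgsStep:
    """Hypothetical input step: bind inputs to the model signature, fill in
    defaults, and pass every parameter positionally (no kwargs remain)."""

    def __init__(self, model: Callable[..., Any]) -> None:
        self.signature = inspect.signature(model)

    def apply(self, args: Args, kwargs: Kwargs) -> Tuple[Args, Kwargs]:
        bound = self.signature.bind(*args, **kwargs)
        bound.apply_defaults()
        # Assumes no *args/**kwargs parameters, so each value is one positional input.
        return tuple(bound.arguments.values()), {}


class RemoveNoneStep:
    """Hypothetical input step: drop positional inputs that are None."""

    def apply(self, args: Args, kwargs: Kwargs) -> Tuple[Args, Kwargs]:
        return tuple(a for a in args if a is not None), kwargs


class FlattenOutputStep:
    """Hypothetical output step: flatten nested dict/list/tuple outputs
    (e.g. a HuggingFace-style dict of tensors) into one flat list."""

    def apply(self, outputs: Any) -> List[Any]:
        if isinstance(outputs, dict):
            items: Sequence[Any] = list(outputs.values())
        elif isinstance(outputs, (list, tuple)):
            items = outputs
        else:
            return [outputs]
        flat: List[Any] = []
        for item in items:
            flat.extend(self.apply(item))
        return flat


class InputAdapter:
    """Applies input steps in order; the final result must be positional-only."""

    def __init__(self, steps: Sequence[Any] = ()) -> None:
        self.steps = list(steps)

    def apply(self, *args: Any, **kwargs: Any) -> Args:
        for step in self.steps:
            args, kwargs = step.apply(args, kwargs)
        if kwargs:
            raise ValueError(f"unconsumed kwargs after adaptation: {sorted(kwargs)}")
        return tuple(args)


class OutputAdapter:
    """Applies output steps in order."""

    def __init__(self, steps: Sequence[Any] = ()) -> None:
        self.steps = list(steps)

    def apply(self, outputs: Any) -> Any:
        for step in self.steps:
            outputs = step.apply(outputs)
        return outputs


def forward(input_ids, attention_mask=None, past_key_values=None):
    return {"logits": input_ids, "past_key_values": ((1, 2), (3, 4))}


inputs = InputAdapter([BindKwargsIntoArgsStep(forward), RemoveNoneStep()])
outputs = OutputAdapter([FlattenOutputStep()])
print(inputs.apply(7, attention_mask=8))  # (7, 8)
print(outputs.apply(forward(7)))  # [7, 1, 2, 3, 4]
```

Keeping each transformation as a separate step lets tracers compose only the steps they need.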

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99940
Approved by: https://github.com/BowenBao, https://github.com/jansel
2023-04-27 00:25:28 +00:00


from __future__ import annotations

import os
from typing import Tuple

import onnx

import torch
from torch.onnx._internal import _beartype


@_beartype.beartype
def _create_tensor_proto_with_external_data(
    tensor: torch.Tensor, name: str, location: str, basepath: str
) -> onnx.TensorProto:
    """Create a TensorProto with external data from a PyTorch tensor.

    The external data is saved to os.path.join(basepath, location).

    Args:
        tensor: Tensor to be saved.
        name: Name of the tensor (i.e., initializer name in ONNX graph).
        location: Location of the external data file, relative to the directory
            containing the ONNX model (e.g., "initializers/weight_0" when the
            model is "/tmp/model_name.onnx").
        basepath: Base path of the external data file. ONNX resolves "location"
            relative to the model's directory, so this should be the directory
            containing the ONNX model (e.g., "/tmp" when the model is
            "/tmp/model_name.onnx").

    Reference for ONNX's external data format:
        How to load?
        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L187
        How to save?
        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L43
        How to set ONNX fields?
        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L88
    """
    tensor_proto = onnx.TensorProto()
    tensor_proto.name = name
    tensor_proto.data_type = torch.onnx._type_utils._SCALAR_TYPE_TO_ONNX[  # type: ignore[assignment]
        torch.onnx._type_utils._DTYPE_TO_SCALAR_TYPE[tensor.dtype]
    ]
    tensor_proto.dims.extend(tensor.shape)
    tensor_proto.data_location = onnx.TensorProto.EXTERNAL

    # Settings for saving one tensor per file.
    # Offset is zero because there is no other tensor in the same file.
    key_value_pairs = {
        "location": location,
        "offset": 0,
        "length": tensor.untyped_storage().nbytes(),
    }
    for k, v in key_value_pairs.items():
        entry = tensor_proto.external_data.add()
        entry.key = k
        entry.value = str(v)

    # Actual path to write the tensor's content to.
    external_data_file_path = os.path.join(basepath, location)
    if os.path.exists(external_data_file_path):
        os.remove(external_data_file_path)

    # Create the external data folder if it does not exist.
    external_data_dir_path = os.path.dirname(external_data_file_path)
    if not os.path.exists(external_data_dir_path):
        os.makedirs(external_data_dir_path)

    # Create a fresh file.
    with open(external_data_file_path, "xb") as data_file:
        # No need to call "seek" because offset is 0.
        # data_file.seek(0)
        # Write tensor content to the file.
        data_file.write(tensor.numpy().tobytes())

    return tensor_proto


@_beartype.beartype
def save_model_with_external_data(
    basepath: str,
    model_location: str,
    initializer_location: str,
    torch_load_paths: Tuple[str, ...],
    onnx_model: onnx.ModelProto,
) -> None:
    """Load PyTorch tensors from files and add to "onnx_model" as external initializers.

    Output files:
        ONNX model file path: os.path.join(basepath, model_location)
        ONNX initializer folder: os.path.join(basepath, initializer_location)

    After running this function, you can do
        ort_sess = onnxruntime.InferenceSession(os.path.join(basepath, model_location))
    to execute the model.

    Arguments:
        basepath: Base path of the ONNX model file and its external data files
            (e.g., "/tmp/large-onnx-model").
        model_location: Relative location of the ONNX model file.
            E.g., "model.onnx" so that the model file is saved to
            "/tmp/large-onnx-model/model.onnx".
        initializer_location: Relative location of the ONNX initializer folder.
            E.g., "initializers" so that the initializers are saved to
            "/tmp/large-onnx-model/initializers".
        torch_load_paths: Files containing serialized PyTorch tensors to be saved
            as ONNX initializers. They are loaded by torch.load.
        onnx_model: ONNX model to be saved with external initializers.
            If an input name matches a tensor loaded from "torch_load_paths",
            the tensor will be saved as that input's external initializer.
    """
    onnx_model_with_initializers = onnx.ModelProto()
    onnx_model_with_initializers.CopyFrom(onnx_model)
    onnx_input_names = [input.name for input in onnx_model.graph.input]

    for path in torch_load_paths:
        state_dict = torch.load(path)
        for name, tensor in state_dict.items():
            # Basically, "transformer.attention.self.query.weight" is mapped
            # to "transformer_attention_self_query_weight" for mimicking the
            # name-modifying code in FX-to-ONNX exporter.
            # See function _replace_get_attr_with_placeholder for details.
            refined_name = name.replace(".", "_")

            # For each refined PyTorch tensor name loaded by torch.load,
            #  1. Search for its match among the ONNX model inputs (the first
            #     name where one is a suffix of the other). E.g., the match of
            #     "transformer_attention_weight" could be "attention_weight".
            #  2. Set "tensor" as the initializer of the matched ONNX input.
            #     E.g., "tensor" is stored as the initializer of "attention_weight".
            # Step 1 is required because the tensor names in the dictionary
            # loaded by torch.load and the ONNX input names sometimes differ
            # by a prefix.
            for onnx_input_name in onnx_input_names:
                if onnx_input_name.endswith(refined_name) or refined_name.endswith(
                    onnx_input_name
                ):
                    # Found a match. Change refined_name to the matched ONNX input name,
                    # so that we create the initializer with the right ONNX name.
                    refined_name = onnx_input_name
                    break

            relative_tensor_file_path = os.path.join(initializer_location, refined_name)
            # Create one file per tensor.
            # The tensor's content is written to the external file at
            # os.path.join(basepath, relative_tensor_file_path).
            tensor_proto = _create_tensor_proto_with_external_data(
                tensor, refined_name, relative_tensor_file_path, basepath
            )
            # Add the tensor_proto to the ONNX model as an initializer with external data.
            onnx_model_with_initializers.graph.initializer.append(tensor_proto)

    # model_location should be a pure file name such as "file_name.onnx", not
    # "folder/file_name.onnx", because the external data locations above are
    # relative to basepath and ONNX resolves them relative to the model's directory.
    onnx.save(onnx_model_with_initializers, os.path.join(basepath, model_location))
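The standard library alone is enough to illustrate two things: the one-tensor-per-file external-data layout used by `_create_tensor_proto_with_external_data` (a `location`, an `offset` of 0 and a `length`, all stored as strings), and the suffix-based name matching in `save_model_with_external_data`. This is a hypothetical sketch, not part of the module. `refine_and_match_name`, `write_one_tensor_per_file` and `read_back` are illustrative names. A list of Python floats stands in for a `torch.Tensor`, and a plain dict of strings stands in for the `TensorProto.external_data` key/value entries. In this sketch, `length` is the number of bytes actually written.

```python
import array
import os
import sys
import tempfile
from typing import Dict, List


def refine_and_match_name(state_dict_key: str, onnx_input_names: List[str]) -> str:
    """Hypothetical helper mirroring the matching loop: dots become underscores,
    then the first ONNX input name that is a suffix of the refined name (or the
    other way round) wins; otherwise the refined name is kept."""
    refined = state_dict_key.replace(".", "_")
    for onnx_name in onnx_input_names:
        if onnx_name.endswith(refined) or refined.endswith(onnx_name):
            return onnx_name
    return refined


def write_one_tensor_per_file(values: List[float], basepath: str, location: str) -> Dict[str, str]:
    """Write float32 values as raw little-endian bytes to basepath/location and
    return external-data style key/value entries (all values are strings)."""
    data = array.array("f", values)  # 'f' is a 4-byte C float on mainstream platforms
    if sys.byteorder != "little":
        data.byteswap()  # ONNX raw tensor data is little-endian
    path = os.path.join(basepath, location)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "wb") as f:
        f.write(data.tobytes())
    # One tensor per file, so the offset is always 0.
    return {"location": location, "offset": "0", "length": str(len(data) * data.itemsize)}


def read_back(basepath: str, entries: Dict[str, str]) -> List[float]:
    """Read the bytes described by the entries, resolving location against basepath."""
    with open(os.path.join(basepath, entries["location"]), "rb") as f:
        f.seek(int(entries["offset"]))
        raw = f.read(int(entries["length"]))
    data = array.array("f")
    data.frombytes(raw)
    if sys.byteorder != "little":
        data.byteswap()
    return data.tolist()


print(refine_and_match_name("transformer.attention.weight", ["attention_weight", "fc_bias"]))
# attention_weight

with tempfile.TemporaryDirectory() as model_dir:
    entries = write_one_tensor_per_file([1.0, 2.5, -3.0], model_dir, os.path.join("initializers", "w"))
    print(entries["length"], read_back(model_dir, entries))  # 12 [1.0, 2.5, -3.0]
```

Writing one tensor per file keeps `offset` at 0 and lets each file be overwritten independently. The trade-off is one file-system entry per initializer.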