[ONNX] Remove legacy dynamo graph extractor (#158262)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158262
Approved by: https://github.com/justinchuby
ghstack dependencies: #158258
This commit is contained in:
parent 19625daf88
commit 205241a0d5
@@ -3,31 +3,18 @@ from __future__ import annotations

__all__ = [
    "ExportOptions",
    "ONNXRuntimeOptions",
    "OnnxRegistry",
    "enable_fake_mode",
]


import abc
import contextlib
import dataclasses
import logging
import warnings
from collections import defaultdict
from typing import Any, Callable, TYPE_CHECKING
from typing_extensions import deprecated
from typing import Any, TYPE_CHECKING

import torch
import torch._ops
from torch.onnx._internal._lazy_import import onnxscript_apis
from torch.onnx._internal.exporter import _constants
from torch.onnx._internal.fx import (
    decomposition_table,
    patcher as patcher,
    registration,
)
from torch.onnx._internal.fx import patcher as patcher


# We can only import onnx from this module in a type-checking context to ensure that
@@ -35,10 +22,6 @@ from torch.onnx._internal.fx import (

# 'import onnx' inside of dynamo_export (by way of _assert_dependencies).
if TYPE_CHECKING:
    import io
    from collections.abc import Mapping, Sequence

    import onnxruntime
    import onnxscript

    from torch._subclasses import fake_tensor
@@ -61,219 +44,6 @@ class ONNXFakeContext:

    """List of paths of files that contain the model :meth:`state_dict`"""


@deprecated(
    "torch.onnx.dynamo_export is deprecated since 2.7.0. Please use torch.onnx.export(..., dynamo=True) instead.",
)
class OnnxRegistry:
    """Registry for ONNX functions.

    .. deprecated:: 2.7
        Please use ``torch.onnx.export(..., dynamo=True)`` instead.

    The registry maintains a mapping from qualified names to symbolic functions under a
    fixed opset version. It supports registering custom onnx-script functions and allows
    the dispatcher to dispatch calls to the appropriate function.
    """

    def __init__(self) -> None:
        """Initializes the registry."""
        # NOTE: _registry maps an OpName to a list of ONNXFunctions. It is important
        # not to directly modify this variable. Instead, access to it should be done through
        # the public methods: register_custom_op, get_ops, and is_registered_op.
        self._registry: dict[registration.OpName, list[registration.ONNXFunction]] = (
            defaultdict(list)
        )

        self._opset_version = _constants.TORCHLIB_OPSET
        warnings.warn(
            f"torch.onnx.dynamo_export only implements opset version {self._opset_version} for now. If you need to use a "
            "different opset version, please register them with register_custom_op."
        )

        self._initiate_registry_from_torchlib()

    @property
    def opset_version(self) -> int:
        """The ONNX opset version the exporter should target."""
        return self._opset_version

    def _initiate_registry_from_torchlib(self) -> None:
        """Populates the registry with ATen functions from torchlib."""
        for meta in onnxscript_apis.get_torchlib_ops():
            internal_name_instance = registration.OpName.from_qualified_name(
                meta.qualified_name
            )
            symbolic_function = registration.ONNXFunction(
                onnx_function=meta.function,  # type: ignore[arg-type]
                op_full_name=internal_name_instance.qualified_name(),
                is_custom=False,
                is_complex=meta.is_complex,
            )
            self._register(internal_name_instance, symbolic_function)

    def _register(
        self,
        internal_qualified_name: registration.OpName,
        symbolic_function: registration.ONNXFunction,
    ) -> None:
        """Registers an ONNXFunction to an operator.

        Args:
            internal_qualified_name: The qualified name of the operator to register: OpName.
            symbolic_function: The ONNXFunction to register.
        """
        self._registry[internal_qualified_name].append(symbolic_function)

    def register_op(
        self,
        function: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction,
        namespace: str,
        op_name: str,
        overload: str | None = None,
        is_complex: bool = False,
    ) -> None:
        """Registers a custom operator: torch.ops.<namespace>.<op_name>.<overload>.

        Args:
            function: The onnx-script function to register.
            namespace: The namespace of the operator to register.
            op_name: The name of the operator to register.
            overload: The overload of the operator to register. If it's the default
                overload, leave it as None.
            is_complex: Whether the function handles complex-valued inputs.

        Raises:
            ValueError: If the name is not in the form of 'namespace::op'.
        """
        internal_name_instance = registration.OpName.from_name_parts(
            namespace=namespace, op_name=op_name, overload=overload
        )
        symbolic_function = registration.ONNXFunction(
            onnx_function=function,
            op_full_name=internal_name_instance.qualified_name(),
            is_custom=True,
            is_complex=is_complex,
        )
        self._register(internal_name_instance, symbolic_function)

    def get_op_functions(
        self, namespace: str, op_name: str, overload: str | None = None
    ) -> list[registration.ONNXFunction] | None:
        """Returns a list of ONNXFunctions for the given op: torch.ops.<namespace>.<op_name>.<overload>.

        The list is ordered by the time of registration. The custom operators should be
        in the second half of the list.

        Args:
            namespace: The namespace of the operator to get.
            op_name: The name of the operator to get.
            overload: The overload of the operator to get. If it's the default
                overload, leave it as None.

        Returns:
            A list of ONNXFunctions corresponding to the given name, or None if
            the name is not in the registry.
        """
        internal_name_instance = registration.OpName.from_name_parts(
            namespace=namespace, op_name=op_name, overload=overload
        )
        return self._registry.get(internal_name_instance)

    def is_registered_op(
        self, namespace: str, op_name: str, overload: str | None = None
    ) -> bool:
        """Returns whether the given op is registered: torch.ops.<namespace>.<op_name>.<overload>.

        Args:
            namespace: The namespace of the operator to check.
            op_name: The name of the operator to check.
            overload: The overload of the operator to check. If it's the default
                overload, leave it as None.

        Returns:
            True if the given op is registered, otherwise False.
        """
        functions = self.get_op_functions(
            namespace=namespace, op_name=op_name, overload=overload
        )
        return functions is not None

    def _all_registered_ops(self) -> set[str]:
        """Returns the set of all registered function names."""
        return {
            op_name_class.qualified_name() for op_name_class in self._registry.keys()
        }
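The registration flow above is easiest to see end to end. A minimal sketch, assuming onnxscript is installed; the com.example domain and the custom_relu decomposition are illustrative, not part of this module:

import onnxscript
from onnxscript.onnx_opset import opset18 as op

custom_domain = onnxscript.values.Opset(domain="com.example", version=1)

@onnxscript.script(custom_domain)
def custom_relu(x):
    # Express relu through primitive ONNX ops, purely for illustration.
    return op.Max(x, op.CastLike(op.Constant(value_float=0.0), x))

registry = OnnxRegistry()
# Extends the torchlib entry for torch.ops.aten.relu; custom functions land
# in the second half of the op's list, as get_op_functions documents.
registry.register_op(function=custom_relu, namespace="aten", op_name="relu")
assert registry.is_registered_op(namespace="aten", op_name="relu")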
@deprecated(
    "torch.onnx.dynamo_export is deprecated since 2.7.0. Please use torch.onnx.export(..., dynamo=True) instead.",
    category=None,
)
class ExportOptions:
    """Options to influence the TorchDynamo ONNX exporter.

    .. deprecated:: 2.7
        Please use ``torch.onnx.export(..., dynamo=True)`` instead.

    Attributes:
        dynamic_shapes: Shape information hint for input/output tensors.
            When ``None``, the exporter determines the most compatible setting.
            When ``True``, all input shapes are considered dynamic.
            When ``False``, all input shapes are considered static.
        fake_context: The fake context used for symbolic tracing.
        onnx_registry: The ONNX registry used to register ATen operators to ONNX functions.
    """

    def __init__(
        self,
        *,
        dynamic_shapes: bool | None = True,
        fake_context: ONNXFakeContext | None = None,
        onnx_registry: OnnxRegistry | None = None,
    ):
        self.dynamic_shapes = dynamic_shapes
        self.fake_context = fake_context
        self.onnx_registry = onnx_registry
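Per the deprecation notices above, the replacement for this options class is the dynamo-based exporter. A minimal sketch of that path; the model and input shapes are illustrative, and shape hints move to the dynamic_shapes argument of torch.onnx.export:

import torch

model = torch.nn.Linear(8, 4)
example_inputs = (torch.randn(2, 8),)

# dynamo=True routes export through torch.export instead of TorchScript.
onnx_program = torch.onnx.export(model, example_inputs, dynamo=True)
onnx_program.save("linear.onnx")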
@deprecated(
    "torch.onnx.dynamo_export is deprecated since 2.7.0. Please use torch.onnx.export(..., dynamo=True) instead.",
    category=None,
)
class ResolvedExportOptions(ExportOptions):
    """Consolidates :class:`ExportOptions` with default values.

    All unspecified options from :class:`ExportOptions` are assigned a default value.
    This is an internal class and its API may be changed at any time without notice.
    """

    def __init__(self):
        from torch.onnx._internal.fx import (
            dynamo_graph_extractor,
            onnxfunction_dispatcher,
        )

        self.dynamic_shapes: bool = True
        self.fx_tracer: dynamo_graph_extractor.DynamoExport = (
            dynamo_graph_extractor.DynamoExport()
        )
        self.fake_context = None
        self.onnx_registry: OnnxRegistry = OnnxRegistry()
        self.decomposition_table = (
            decomposition_table.create_onnx_friendly_decomposition_table(  # type: ignore[assignment]
                self.onnx_registry
            )
        )
        self.onnxfunction_dispatcher = onnxfunction_dispatcher.OnnxFunctionDispatcher(
            self.onnx_registry,
        )
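For orientation, the objects this class wires together were consumed by the legacy exporter roughly as follows. A sketch only, with an illustrative toy model; DynamoExport is the extractor this PR removes:

import torch

options = ResolvedExportOptions()
model = torch.nn.Linear(2, 2)
args = (torch.randn(1, 2),)

# fx_tracer captures the model as an FX graph; the onnxfunction_dispatcher
# then resolves each FX node to a function registered in onnx_registry.
graph_module = options.fx_tracer.generate_fx(options, model, args, {})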
@contextlib.contextmanager
def enable_fake_mode():
    """Enable fake mode for the duration of the context.

@@ -346,101 +116,3 @@ def enable_fake_mode():

    fake_context.state_dict_paths = tuple(
        patcher_context.paths,
    )  # type: ignore[assignment]
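A sketch of how the deprecated fake-mode flow fit together, based on the classes in this file; the model is illustrative:

import torch

with torch.onnx.enable_fake_mode() as fake_context:
    # Weights are fake tensors: no real storage is allocated, so models
    # larger than available memory can still be traced for export.
    model = torch.nn.Linear(4096, 4096)
    args = (torch.randn(1, 4096),)

export_options = torch.onnx.ExportOptions(fake_context=fake_context)
onnx_program = torch.onnx.dynamo_export(model, *args, export_options=export_options)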
@deprecated(
    "torch.onnx.dynamo_export is deprecated since 2.7.0. Please use torch.onnx.export(..., dynamo=True) instead.",
)
class ONNXRuntimeOptions:
    """Options to influence the execution of the ONNX model through ONNX Runtime.

    .. deprecated:: 2.7
        Please use ``torch.onnx.export(..., dynamo=True)`` instead.

    Attributes:
        session_options: ONNX Runtime session options.
        execution_providers: ONNX Runtime execution providers to use during model execution.
        execution_provider_options: ONNX Runtime execution provider options.
    """

    session_options: Sequence[onnxruntime.SessionOptions] | None = None
    """ONNX Runtime session options."""

    execution_providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None
    """ONNX Runtime execution providers to use during model execution."""

    execution_provider_options: Sequence[dict[Any, Any]] | None = None
    """ONNX Runtime execution provider options."""

    def __init__(
        self,
        *,
        session_options: Sequence[onnxruntime.SessionOptions] | None = None,
        execution_providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None,
        execution_provider_options: Sequence[dict[Any, Any]] | None = None,
    ):
        self.session_options = session_options
        self.execution_providers = execution_providers
        self.execution_provider_options = execution_provider_options
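How these options typically reach ONNX Runtime; a sketch assuming onnxruntime is installed and a "model.onnx" file exists:

import onnxruntime

options = ONNXRuntimeOptions(execution_providers=["CPUExecutionProvider"])
# The provider list is forwarded straight into the inference session.
session = onnxruntime.InferenceSession(
    "model.onnx",
    providers=list(options.execution_providers),
)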
class FXGraphExtractor(abc.ABC):
    """Abstract interface for FX graph extractor engines.

    This class isolates FX extraction logic from the rest of the export logic.
    That allows a single ONNX exporter that can leverage different FX graphs."""

    def __init__(self) -> None:
        super().__init__()

    @abc.abstractmethod
    def generate_fx(
        self,
        options: ResolvedExportOptions,
        model: torch.nn.Module | Callable,
        model_args: Sequence[Any],
        model_kwargs: Mapping[str, Any],
    ) -> torch.fx.GraphModule:
        """Analyzes user ``model`` and generates an FX graph.

        Args:
            options: The export options.
            model: The user model.
            model_args: The model's positional input arguments.
            model_kwargs: The model's keyword input arguments.

        Returns:
            The generated FX Graph.
        """
        ...

    # TODO: Design the passes API
    @abc.abstractmethod
    def pre_export_passes(
        self,
        options: ResolvedExportOptions,
        original_model: torch.nn.Module | Callable,
        fx_module: torch.fx.GraphModule,
        fx_module_args: Sequence[Any],
    ):
        """Applies pre-export passes to the FX graph.

        Pre-export passes are FX-to-FX graph transformations that make the graph
        more palatable for the FX-to-ONNX conversion.
        For example, it can be used to flatten model input/output, add explicit
        casts to the graph, replace/decompose operators, functionalize the graph, etc.
        """
        ...
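The contract is that generate_fx captures the graph and pre_export_passes normalizes it. A hypothetical minimal subclass to make the interface concrete; torch.fx.symbolic_trace stands in for a real capture engine, and DynamoExport (in the file removed below) was the actual implementation:

class SymbolicTraceExtractor(FXGraphExtractor):
    """Hypothetical extractor backed by torch.fx.symbolic_trace."""

    def generate_fx(self, options, model, model_args, model_kwargs):
        graph_module = torch.fx.symbolic_trace(model)
        return self.pre_export_passes(options, model, graph_module, model_args)

    def pre_export_passes(self, options, original_model, fx_module, fx_module_args):
        return common_pre_export_passes(
            options, original_model, fx_module, fx_module_args
        )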
def common_pre_export_passes(
    options: ResolvedExportOptions,
    original_model: torch.nn.Module | Callable,
    fx_module: torch.fx.GraphModule,
    fx_module_args: Sequence[Any],
):
    # TODO: Import here to prevent circular dependency
    from torch.onnx._internal.fx import passes

    # ONNX does not support the concept of (implicit) type promotion.
    # Insert type casts explicitly where needed.
    module = passes.InsertTypePromotion(fx_module).run()

    return module
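The promotion problem this pass solves, in plain PyTorch terms:

import torch

a = torch.ones(2, dtype=torch.float16)
b = torch.ones(2, dtype=torch.float32)
# PyTorch promotes implicitly: float16 + float32 computes in float32.
assert (a + b).dtype == torch.float32
# ONNX Add requires matching input types, so the exported graph needs an
# explicit Cast of `a` to float32; InsertTypePromotion adds such casts at
# the FX level before conversion.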
@@ -1,160 +0,0 @@

# mypy: allow-untyped-defs
# NOTE: This file is referenced by name at
# /opt/pytorch/torch/_dynamo/eval_frame.py::DONT_WRAP_FILES.
# introduced by https://github.com/pytorch/pytorch/pull/98894.
# If this file is renamed, moved, etc please update the reference there!

from __future__ import annotations

import contextlib
import inspect
from typing import Any, Callable, TYPE_CHECKING

import torch._dynamo
import torch.fx
from torch.onnx._internal import _exporter_legacy
from torch.utils import _pytree as pytree


if TYPE_CHECKING:
    from collections.abc import Mapping, Sequence
class _PyTreeExtensionContext:
    """Context manager to register PyTree extension."""

    _extensions: dict[type, tuple[pytree.FlattenFunc, pytree.UnflattenFunc]]

    def __init__(self) -> None:
        self._extensions = {}
        # Register PyTree extension for HuggingFace model output.
        self._register_huggingface_model_output_extension()

    def __enter__(self):
        for class_type, (flatten_func, unflatten_func) in self._extensions.items():
            pytree._private_register_pytree_node(
                class_type,
                flatten_func,
                unflatten_func,
            )
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        for class_type in self._extensions:
            pytree.SUPPORTED_NODES.pop(class_type)

    def register_pytree_node(
        self,
        class_type: type,
        flatten_func: pytree.FlattenFunc,
        unflatten_func: pytree.UnflattenFunc,
    ):
        """Register PyTree extension for a custom python type.

        Args:
            class_type: The custom python type.
            flatten_func: The flatten function.
            unflatten_func: The unflatten function.

        Raises:
            AssertionError: If the custom python type is already registered.
        """
        if class_type in pytree.SUPPORTED_NODES or class_type in self._extensions:
            # PyTree node already registered.
            # E.g., `huggingface/transformer` registers `ModelOutput` as PyTree node after
            # https://github.com/huggingface/transformers/pull/25358.
            return
        self._extensions[class_type] = (flatten_func, unflatten_func)

    def _register_huggingface_model_output_extension(self):
        try:
            from transformers import modeling_outputs  # type: ignore[import]
        except ImportError:
            return

        def model_output_flatten(
            output: modeling_outputs.ModelOutput,
        ) -> tuple[list[Any], pytree.Context]:
            return list(output.values()), (type(output), list(output.keys()))

        def model_output_unflatten(
            values: list[Any], context: pytree.Context
        ) -> modeling_outputs.ModelOutput:
            output_type, keys = context
            return output_type(**dict(zip(keys, values)))

        # All 'ModelOutput' subclasses are defined under module 'modeling_outputs'.
        named_model_output_classes = inspect.getmembers(
            modeling_outputs,
            lambda x: (
                inspect.isclass(x)
                and issubclass(x, modeling_outputs.ModelOutput)
                and x is not modeling_outputs.ModelOutput
            ),
        )

        for _, class_type in named_model_output_classes:
            self.register_pytree_node(
                class_type,
                model_output_flatten,
                model_output_unflatten,  # type: ignore[arg-type]
            )
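What registering a pytree node buys: tree_flatten can decompose instances of the type into leaves and tree_unflatten can rebuild them, exactly like the ModelOutput handlers above. A sketch with an illustrative Point type:

import torch.utils._pytree as pytree

class Point:
    def __init__(self, x, y):
        self.x, self.y = x, y

def point_flatten(p):
    return [p.x, p.y], None  # (children, context)

def point_unflatten(values, context):
    return Point(*values)

ctx = _PyTreeExtensionContext()
ctx.register_pytree_node(Point, point_flatten, point_unflatten)
with ctx:
    leaves, spec = pytree.tree_flatten(Point(1, 2))
    assert leaves == [1, 2]
    restored = pytree.tree_unflatten(leaves, spec)
    assert isinstance(restored, Point)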
class DynamoExport(_exporter_legacy.FXGraphExtractor):
    """Generates an FX GraphModule using the torch.dynamo.export API.

    Args:
        aten_graph: If True, exports a graph with ATen operators.
            If False, exports a graph with Python operators.
    """

    def __init__(
        self,
        aten_graph: bool | None = None,
    ):
        super().__init__()
        # Note: `x or True` always evaluates to True, so aten_graph=False is
        # effectively ignored here and ATen graphs are always produced.
        self.aten_graph = aten_graph or True

    def generate_fx(
        self,
        options: _exporter_legacy.ResolvedExportOptions,
        model: torch.nn.Module | Callable,
        model_args: Sequence[Any],
        model_kwargs: Mapping[str, Any],
    ) -> torch.fx.GraphModule:
        # `dynamo.export` does not recognize custom user defined classes as output type.
        # Apply wrapper to adapt the outputs back to `dynamo.export` compatible types,
        # i.e. :class:`torch.Tensor`.
        wrapped_model = model

        # Translate callable to FX graph.
        fake_mode = (
            options.fake_context.fake_mode
            if options.fake_context
            else contextlib.nullcontext()
        )
        fx_mode = "symbolic" if options.dynamic_shapes else "fake"
        with fake_mode:  # type: ignore[attr-defined]
            graph_module, graph_guard = torch._dynamo.export(
                wrapped_model,
                tracing_mode=fx_mode,
            )(
                *model_args,
                **model_kwargs,
            )
        del graph_guard  # Unused
        torch._dynamo.reset()

        return self.pre_export_passes(options, model, graph_module, model_args)  # type: ignore[return-value]

    def pre_export_passes(
        self,
        options: _exporter_legacy.ResolvedExportOptions,
        original_model: torch.nn.Module | Callable,
        fx_module: torch.fx.GraphModule,
        fx_module_args: Sequence[Any],
    ):
        return _exporter_legacy.common_pre_export_passes(
            options, original_model, fx_module, fx_module_args
        )
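The capture call at the heart of generate_fx can be exercised directly. A sketch against the same private API, which may change between torch versions; fn is illustrative:

import torch

def fn(x):
    return torch.sin(x) + 1.0

# torch._dynamo.export(...) returns a callable; invoking it with example
# inputs yields a (GraphModule, guards) pair, as consumed by generate_fx above.
graph_module, guards = torch._dynamo.export(fn, tracing_mode="symbolic")(torch.randn(3))
print(graph_module.graph)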
@@ -25,9 +25,6 @@ if TYPE_CHECKING:

        graph_building as onnxscript_graph_building,
    )

    from torch.onnx._internal._exporter_legacy import OnnxRegistry


logger = logging.getLogger(__name__)
@@ -58,7 +55,7 @@ class OnnxFunctionDispatcher:

    def __init__(
        self,
        onnx_registry: OnnxRegistry,
        onnx_registry,
    ):
        """Initialize the ONNX Function dispatcher.