[ONNX] Remove legacy dynamo graph extractor (#158262)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158262
Approved by: https://github.com/justinchuby
ghstack dependencies: #158258
This commit is contained in:
Ti-Tai Wang 2025-07-15 17:32:59 +00:00 committed by PyTorch MergeBot
parent 19625daf88
commit 205241a0d5
3 changed files with 3 additions and 494 deletions

View File

@ -3,31 +3,18 @@ from __future__ import annotations
__all__ = [
"ExportOptions",
"ONNXRuntimeOptions",
"OnnxRegistry",
"enable_fake_mode",
]
import abc
import contextlib
import dataclasses
import logging
import warnings
from collections import defaultdict
from typing import Any, Callable, TYPE_CHECKING
from typing_extensions import deprecated
from typing import Any, TYPE_CHECKING
import torch
import torch._ops
from torch.onnx._internal._lazy_import import onnxscript_apis
from torch.onnx._internal.exporter import _constants
from torch.onnx._internal.fx import (
decomposition_table,
patcher as patcher,
registration,
)
from torch.onnx._internal.fx import patcher as patcher
# We can only import onnx from this module in a type-checking context to ensure that
@ -35,10 +22,6 @@ from torch.onnx._internal.fx import (
# 'import onnx' inside of dynamo_export (by way of _assert_dependencies).
if TYPE_CHECKING:
import io
from collections.abc import Mapping, Sequence
import onnxruntime
import onnxscript
from torch._subclasses import fake_tensor
@ -61,219 +44,6 @@ class ONNXFakeContext:
"""List of paths of files that contain the model :meth:`state_dict`"""
@deprecated(
"torch.onnx.dynamo_export is deprecated since 2.7.0. Please use torch.onnx.export(..., dynamo=True) instead.",
)
class OnnxRegistry:
"""Registry for ONNX functions.
.. deprecated:: 2.7
Please use ``torch.onnx.export(..., dynamo=True)`` instead.
The registry maintains a mapping from qualified names to symbolic functions under a
fixed opset version. It supports registering custom onnx-script functions and for
dispatcher to dispatch calls to the appropriate function.
"""
def __init__(self) -> None:
"""Initializes the registry"""
# NOTE: _registry is the registry maps OpNameto a list of ONNXFunctions. It is important
# not to directly modify this variable. Instead, access to it should be done through
# the public methods: register_custom_op, get_ops, and is_registered_op.
self._registry: dict[registration.OpName, list[registration.ONNXFunction]] = (
defaultdict(list)
)
self._opset_version = _constants.TORCHLIB_OPSET
warnings.warn(
f"torch.onnx.dynamo_export only implements opset version {self._opset_version} for now. If you need to use a "
"different opset version, please register them with register_custom_op."
)
self._initiate_registry_from_torchlib()
@property
def opset_version(self) -> int:
"""The ONNX opset version the exporter should target."""
return self._opset_version
def _initiate_registry_from_torchlib(self) -> None:
"""Populates the registry with ATen functions from torchlib.
Args:
torchlib_registry: The torchlib registry to use for populating the registry.
"""
for meta in onnxscript_apis.get_torchlib_ops():
internal_name_instance = registration.OpName.from_qualified_name(
meta.qualified_name
)
symbolic_function = registration.ONNXFunction(
onnx_function=meta.function, # type: ignore[arg-type]
op_full_name=internal_name_instance.qualified_name(),
is_custom=False,
is_complex=meta.is_complex,
)
self._register(internal_name_instance, symbolic_function)
def _register(
self,
internal_qualified_name: registration.OpName,
symbolic_function: registration.ONNXFunction,
) -> None:
"""Registers a ONNXFunction to an operator.
Args:
internal_qualified_name: The qualified name of the operator to register: OpName.
symbolic_function: The ONNXFunction to register.
"""
self._registry[internal_qualified_name].append(symbolic_function)
def register_op(
self,
function: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction,
namespace: str,
op_name: str,
overload: str | None = None,
is_complex: bool = False,
) -> None:
"""Registers a custom operator: torch.ops.<namespace>.<op_name>.<overload>.
Args:
function: The onnx-sctip function to register.
namespace: The namespace of the operator to register.
op_name: The name of the operator to register.
overload: The overload of the operator to register. If it's default overload,
leave it to None.
is_complex: Whether the function is a function that handles complex valued inputs.
Raises:
ValueError: If the name is not in the form of 'namespace::op'.
"""
internal_name_instance = registration.OpName.from_name_parts(
namespace=namespace, op_name=op_name, overload=overload
)
symbolic_function = registration.ONNXFunction(
onnx_function=function,
op_full_name=internal_name_instance.qualified_name(),
is_custom=True,
is_complex=is_complex,
)
self._register(internal_name_instance, symbolic_function)
def get_op_functions(
self, namespace: str, op_name: str, overload: str | None = None
) -> list[registration.ONNXFunction] | None:
"""Returns a list of ONNXFunctions for the given op: torch.ops.<namespace>.<op_name>.<overload>.
The list is ordered by the time of registration. The custom operators should be
in the second half of the list.
Args:
namespace: The namespace of the operator to get.
op_name: The name of the operator to get.
overload: The overload of the operator to get. If it's default overload,
leave it to None.
Returns:
A list of ONNXFunctions corresponding to the given name, or None if
the name is not in the registry.
"""
internal_name_instance = registration.OpName.from_name_parts(
namespace=namespace, op_name=op_name, overload=overload
)
return self._registry.get(internal_name_instance)
def is_registered_op(
self, namespace: str, op_name: str, overload: str | None = None
) -> bool:
"""Returns whether the given op is registered: torch.ops.<namespace>.<op_name>.<overload>.
Args:
namespace: The namespace of the operator to check.
op_name: The name of the operator to check.
overload: The overload of the operator to check. If it's default overload,
leave it to None.
Returns:
True if the given op is registered, otherwise False.
"""
functions = self.get_op_functions(
namespace=namespace, op_name=op_name, overload=overload
)
return functions is not None
def _all_registered_ops(self) -> set[str]:
"""Returns the set of all registered function names."""
return {
op_name_class.qualified_name() for op_name_class in self._registry.keys()
}
@deprecated(
"torch.onnx.dynamo_export is deprecated since 2.7.0. Please use torch.onnx.export(..., dynamo=True) instead.",
category=None,
)
class ExportOptions:
"""Options to influence the TorchDynamo ONNX exporter.
.. deprecated:: 2.7
Please use ``torch.onnx.export(..., dynamo=True)`` instead.
Attributes:
dynamic_shapes: Shape information hint for input/output tensors.
When ``None``, the exporter determines the most compatible setting.
When ``True``, all input shapes are considered dynamic.
When ``False``, all input shapes are considered static.
fake_context: The fake context used for symbolic tracing.
onnx_registry: The ONNX registry used to register ATen operators to ONNX functions.
"""
def __init__(
self,
*,
dynamic_shapes: bool | None = True,
fake_context: ONNXFakeContext | None = None,
onnx_registry: OnnxRegistry | None = None,
):
self.dynamic_shapes = dynamic_shapes
self.fake_context = fake_context
self.onnx_registry = onnx_registry
@deprecated(
"torch.onnx.dynamo_export is deprecated since 2.7.0. Please use torch.onnx.export(..., dynamo=True) instead.",
category=None,
)
class ResolvedExportOptions(ExportOptions):
"""Consolidates :class:`ExportOptions` with default values.
All unspecified options from :class:`ExportOptions` are assigned a default value.
This is an internal class and its API may be changed at any time without notice.
"""
def __init__(self):
from torch.onnx._internal.fx import (
dynamo_graph_extractor,
onnxfunction_dispatcher,
)
self.dynamic_shapes: bool = True
self.fx_tracer: dynamo_graph_extractor.DynamoExport = (
dynamo_graph_extractor.DynamoExport()
)
self.fake_context = None
self.onnx_registry: OnnxRegistry = OnnxRegistry()
self.decomposition_table = (
decomposition_table.create_onnx_friendly_decomposition_table( # type: ignore[assignment]
self.onnx_registry
)
)
self.onnxfunction_dispatcher = onnxfunction_dispatcher.OnnxFunctionDispatcher(
self.onnx_registry,
)
@contextlib.contextmanager
def enable_fake_mode():
"""Enable fake mode for the duration of the context.
@ -346,101 +116,3 @@ def enable_fake_mode():
fake_context.state_dict_paths = tuple(
patcher_context.paths,
) # type: ignore[assignment]
@deprecated(
"torch.onnx.dynamo_export is deprecated since 2.7.0. Please use torch.onnx.export(..., dynamo=True) instead.",
)
class ONNXRuntimeOptions:
"""Options to influence the execution of the ONNX model through ONNX Runtime.
.. deprecated:: 2.7
Please use ``torch.onnx.export(..., dynamo=True)`` instead.
Attributes:
session_options: ONNX Runtime session options.
execution_providers: ONNX Runtime execution providers to use during model execution.
execution_provider_options: ONNX Runtime execution provider options.
"""
session_options: Sequence[onnxruntime.SessionOptions] | None = None
"""ONNX Runtime session options."""
execution_providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None
"""ONNX Runtime execution providers to use during model execution."""
execution_provider_options: Sequence[dict[Any, Any]] | None = None
"""ONNX Runtime execution provider options."""
def __init__(
self,
*,
session_options: Sequence[onnxruntime.SessionOptions] | None = None,
execution_providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None,
execution_provider_options: Sequence[dict[Any, Any]] | None = None,
):
self.session_options = session_options
self.execution_providers = execution_providers
self.execution_provider_options = execution_provider_options
class FXGraphExtractor(abc.ABC):
"""Abstract interface for FX graph extractor engines.
This class isolates FX extraction logic from the rest of the export logic.
That allows a single ONNX exporter that can leverage different FX graphs."""
def __init__(self) -> None:
super().__init__()
@abc.abstractmethod
def generate_fx(
self,
options: ResolvedExportOptions,
model: torch.nn.Module | Callable,
model_args: Sequence[Any],
model_kwargs: Mapping[str, Any],
) -> torch.fx.GraphModule:
"""Analyzes user ``model`` and generates a FX graph.
Args:
options: The export options.
model: The user model.
model_args: The model's positional input arguments.
model_kwargs: The model's keyword input arguments.
Returns:
The generated FX Graph.
"""
...
# TODO: Design the passes API
@abc.abstractmethod
def pre_export_passes(
self,
options: ResolvedExportOptions,
original_model: torch.nn.Module | Callable,
fx_module: torch.fx.GraphModule,
fx_module_args: Sequence[Any],
):
"""Applies pre-export passes to the FX graph.
Pre-export passes are FX-to-FX graph transformations that make the graph
more palatable for the FX-to-ONNX conversion.
For example, it can be used to flatten model input/output, add explicit
casts to the graph, replace/decompose operators, functionalize the graph, etc.
"""
...
def common_pre_export_passes(
options: ResolvedExportOptions,
original_model: torch.nn.Module | Callable,
fx_module: torch.fx.GraphModule,
fx_module_args: Sequence[Any],
):
# TODO: Import here to prevent circular dependency
from torch.onnx._internal.fx import passes
# ONNX does not support concept of (implicit) type promotion.
# Insert type casts explicitly where needed.
module = passes.InsertTypePromotion(fx_module).run()
return module

View File

@ -1,160 +0,0 @@
# mypy: allow-untyped-defs
# NOTE: This file is referenced by name at
# /opt/pytorch/torch/_dynamo/eval_frame.py::DONT_WRAP_FILES.
# introduced by https://github.com/pytorch/pytorch/pull/98894.
# If this file is renamed, moved, etc please update the reference there!
from __future__ import annotations
import contextlib
import inspect
from typing import Any, Callable, TYPE_CHECKING
import torch._dynamo
import torch.fx
from torch.onnx._internal import _exporter_legacy
from torch.utils import _pytree as pytree
if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
class _PyTreeExtensionContext:
"""Context manager to register PyTree extension."""
_extensions: dict[type, tuple[pytree.FlattenFunc, pytree.UnflattenFunc]]
def __init__(self) -> None:
self._extensions = {}
# Register PyTree extension for HuggingFace model output.
self._register_huggingface_model_output_extension()
def __enter__(self):
for class_type, (flatten_func, unflatten_func) in self._extensions.items():
pytree._private_register_pytree_node(
class_type,
flatten_func,
unflatten_func,
)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
for class_type in self._extensions:
pytree.SUPPORTED_NODES.pop(class_type)
def register_pytree_node(
self,
class_type: type,
flatten_func: pytree.FlattenFunc,
unflatten_func: pytree.UnflattenFunc,
):
"""Register PyTree extension for a custom python type.
Args:
class_type: The custom python type.
flatten_func: The flatten function.
unflatten_func: The unflatten function.
Raises:
AssertionError: If the custom python type is already registered.
"""
if class_type in pytree.SUPPORTED_NODES or class_type in self._extensions:
# PyTree node already registered.
# E.g., `huggingface/transformer` registers `ModelOutput` as PyTree node after
# https://github.com/huggingface/transformers/pull/25358.
return
self._extensions[class_type] = (flatten_func, unflatten_func)
def _register_huggingface_model_output_extension(self):
try:
from transformers import modeling_outputs # type: ignore[import]
except ImportError:
return
def model_output_flatten(
output: modeling_outputs.ModelOutput,
) -> tuple[list[Any], pytree.Context]:
return list(output.values()), (type(output), list(output.keys()))
def model_output_unflatten(
values: list[Any], context: pytree.Context
) -> modeling_outputs.ModelOutput:
output_type, keys = context
return output_type(**dict(zip(keys, values)))
# All 'ModelOutput' subclasses are defined under module 'modeling_outputs'.
named_model_output_classes = inspect.getmembers(
modeling_outputs,
lambda x: (
inspect.isclass(x)
and issubclass(x, modeling_outputs.ModelOutput)
and x is not modeling_outputs.ModelOutput
),
)
for _, class_type in named_model_output_classes:
self.register_pytree_node(
class_type,
model_output_flatten,
model_output_unflatten, # type: ignore[arg-type ]
)
class DynamoExport(_exporter_legacy.FXGraphExtractor):
"""Generates a FX GraphModule using torch.dynamo.export API
Args:
aten_graph: If True, exports a graph with ATen operators.
If False, exports a graph with Python operators.
"""
def __init__(
self,
aten_graph: bool | None = None,
):
super().__init__()
self.aten_graph = aten_graph or True
def generate_fx(
self,
options: _exporter_legacy.ResolvedExportOptions,
model: torch.nn.Module | Callable,
model_args: Sequence[Any],
model_kwargs: Mapping[str, Any],
) -> torch.fx.GraphModule:
# `dynamo.export` does not recognize custom user defined classes as output type.
# Apply wrapper to adapt the outputs back to `dynamo.export` compatible types,
# i.e. :class:`torch.Tensor`.
wrapped_model = model
# Translate callable to FX graph.
#
fake_mode = (
options.fake_context.fake_mode
if options.fake_context
else contextlib.nullcontext()
)
fx_mode = "symbolic" if options.dynamic_shapes else "fake"
with fake_mode: # type: ignore[attr-defined]
graph_module, graph_guard = torch._dynamo.export(
wrapped_model,
tracing_mode=fx_mode,
)(
*model_args,
**model_kwargs,
)
del graph_guard # Unused
torch._dynamo.reset()
return self.pre_export_passes(options, model, graph_module, model_args) # type: ignore[return-value]
def pre_export_passes(
self,
options: _exporter_legacy.ResolvedExportOptions,
original_model: torch.nn.Module | Callable,
fx_module: torch.fx.GraphModule,
fx_module_args: Sequence[Any],
):
return _exporter_legacy.common_pre_export_passes(
options, original_model, fx_module, fx_module_args
)

View File

@ -25,9 +25,6 @@ if TYPE_CHECKING:
graph_building as onnxscript_graph_building,
)
from torch.onnx._internal._exporter_legacy import OnnxRegistry
logger = logging.getLogger(__name__)
@ -58,7 +55,7 @@ class OnnxFunctionDispatcher:
def __init__(
self,
onnx_registry: OnnxRegistry,
onnx_registry,
):
"""Initialize the ONNX Function dispatcher.