[ONNX] Refactor dispatcher and registry (#147396)

This PR sets up the registry to accept onnx decomp functions to be moved into PyTorch (https://github.com/pytorch/pytorch/issues/139301).

The ops from onnx script are currently appended to the registry. When the ops are moved into PyTorch, the moved ops takes precedence because they appear first in the registry list.

After the migration hooks for loading ops from onnx script will be removed.

1. Use a private field `_pt_onnx_signature` to store function signatures to avoid conflicts
2. Update the registry to record the signature in OnnxDecompMeta and update the dispatcher to leverage the data structure
3. Update registry to prepare for onnx op registration, and update the the onnx_impl decorator to support a no_compile option

Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147396
Approved by: https://github.com/titaiwangms
This commit is contained in:
Justin Chu 2025-02-19 11:23:01 -08:00 committed by PyTorch MergeBot
parent 4f3c070b25
commit 279c7f262e
6 changed files with 145 additions and 92 deletions

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import logging
from collections.abc import Sequence
from typing import Callable
from typing import Any, Callable
from onnxscript import ir
@ -188,11 +188,11 @@ def _get_type_from_tensor(
def _get_first_tensor_in_node_list(
nodes: Sequence[torch.fx.Node | None],
nodes: Sequence[torch.fx.Node | Any],
) -> torch.Tensor | None:
for node in nodes:
if (
node is not None
isinstance(node, torch.fx.Node)
and "val" in node.meta
and isinstance(node.meta["val"], torch.Tensor)
):
@ -213,13 +213,13 @@ def _get_named_fx_node_args(node: torch.fx.Node) -> dict[str, torch.fx.node.Argu
def get_matching_overload(
node: torch.fx.Node,
overloads: Sequence[Callable],
overloads: Sequence[_registration.OnnxDecompMeta],
) -> tuple[Callable | None, str]:
"""Get the overload that matches the node's arguments.
Args:
node: The node to match.
overloads: The overloads to match against.
overloads: The OnnxDecompMeta with overloads and their signatures to match against.
Returns:
A tuple containing the matched overload and a string describing the reason for failure or success.
@ -230,7 +230,7 @@ def get_matching_overload(
# now we assume all inputs are named.
return overloads[
0
], "The node target does not have a schema. Return the first one."
].onnx_function, "The node target does not have a schema. Return the first one."
named_args = _get_named_fx_node_args(node)
# FIXME: Handle when we don't know the names of the arguments
schema_args: dict[str, torch.Argument] = {
@ -241,10 +241,10 @@ def get_matching_overload(
for overload in overloads:
assigned_types: dict[str, ir.TypeProtocol] = {}
fail_reason = ""
if not hasattr(overload, "signature"):
if overload.signature is None:
# When an overload does not have a signature, we assume it is a custom op and should be matched
return (
overload,
overload.onnx_function,
"The overload does not have a signature. Assuming it is a custom op and matching it.",
)
for param in overload.signature:
@ -266,7 +266,7 @@ def get_matching_overload(
arg = schema_args[param.name].default_value
elif param.has_default():
# Provided in the ONNX op definition
arg = param.default
arg = param.default # type: ignore[assignment]
else:
fail_reason = "Parameter not provided"
break
@ -297,8 +297,10 @@ def get_matching_overload(
if not _attribute_type_compatible_with_arg(param, arg): # type: ignore[arg-type]
fail_reason = f"Attribute type not compatible with argument: param=`{param}`, arg=`{arg}`"
break
else:
raise TypeError(f"Unknown parameter type: {type(param)}")
if not fail_reason:
return overload, "Successfully matched overload"
return overload.onnx_function, "Successfully matched overload"
else:
failure_messages.append(
f"- Failed to match overload `{overload}`: {fail_reason}"
@ -357,7 +359,5 @@ def dispatch(
"Fast path: Only one decomposition is defined",
)
overload, message = get_matching_overload(
node, [decomp.onnx_function for decomp in decomp_metas]
)
overload, message = get_matching_overload(node, decomp_metas)
return overload, message

View File

@ -5,7 +5,8 @@ import logging
import re
from typing import TYPE_CHECKING
from torch.onnx._internal._lazy_import import onnxscript_apis, onnxscript_ir as ir
from torch.onnx._internal._lazy_import import onnxscript_ir as ir
from torch.onnx._internal.exporter import _constants
if TYPE_CHECKING:
@ -115,8 +116,7 @@ def _maybe_set_opset_version(
# Already set
return
if domain == _ONNX_DOMAIN:
# Set the default opset version for ONNX operators
opset_imports[domain] = onnxscript_apis.torchlib_opset_version()
opset_imports[domain] = _constants.TORCHLIB_OPSET
return
if version is None:
# We don't know the opset version, so set it to 1

View File

@ -33,22 +33,59 @@ TorchOp: TypeAlias = Union[torch._ops.OpOverload, types.BuiltinFunctionType, Cal
logger = logging.getLogger(__name__)
@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass
class OnnxDecompMeta:
"""A wrapper of onnx-script function with additional metadata.
onnx_function: The onnx-script function from torchlib.
fx_target: The PyTorch node callable target.
signature: The ONNX signature of the function. When None, the signature is inferred.
is_custom: Whether the function is a custom function.
is_complex: Whether the function is a function that handles complex valued inputs.
device: The device the function is registered to. If None, it is registered to all devices.
skip_signature_inference: Whether to skip signature inference for the function.
"""
onnx_function: Callable
fx_target: TorchOp
signature: _schemas.OpSignature | None
is_custom: bool = False
is_complex: bool = False
device: Literal["cuda", "cpu"] | str | None = None # noqa: PYI051
skip_signature_inference: bool = False
def __post_init__(self) -> None:
if self.signature is None and not self.skip_signature_inference:
try:
if isinstance(self.onnx_function, onnxscript.OnnxFunction):
signature = _schemas.OpSignature.from_function( # type: ignore[attr-defined]
self.onnx_function,
self.onnx_function.function_ir.domain,
self.onnx_function.name,
opset_version=self.onnx_function.opset.version,
)
else:
signature = _schemas.OpSignature.from_function(
self.onnx_function, "__traced", self.onnx_function.__name__
)
except Exception as e:
# Log an warning if the op is custom. Raise exception for builtin ops.
if not self.is_custom:
raise
else:
# When the function is targeting an HOP, for example, it will accept
# functions as arguments and fail to generate an ONNX signature.
# In this case we set signature to None and dispatch to this function always.
logger.warning(
"Failed to infer the signature for function '%s' because '%s'"
"All nodes targeting `%s` will be dispatched to this function",
self.onnx_function,
e,
self.fx_target,
)
else:
self.signature = signature
self.onnx_function._pt_onnx_signature = signature # type: ignore[attr-defined]
def _get_overload(qualified_name: str) -> torch._ops.OpOverload | None:
@ -120,14 +157,15 @@ class ONNXRegistry:
torchlib_registry: The torchlib registry to use for populating the registry.
"""
registry = cls()
for meta in _torchlib_registry.get_torchlib_ops():
registry._register(meta.fx_target, meta)
# TODO(justinchuby): Remove this once torchlib is migrated to PyTorch
torchlib_ops = onnxscript_apis.get_torchlib_ops()
for meta in torchlib_ops:
qualified_name = meta.qualified_name
overload_func = meta.function
domain = meta.domain
name = meta.name
for torchlib_meta in torchlib_ops:
qualified_name = torchlib_meta.qualified_name
overload_func = torchlib_meta.function
try:
# NOTE: This is heavily guarded with try-except because we don't want
# to fail the entire registry population if one function fails.
@ -135,42 +173,18 @@ class ONNXRegistry:
if target is None:
continue
if isinstance(overload_func, onnxscript.OnnxFunction):
opset_version = overload_func.opset.version
else:
opset_version = 1
overload_func.signature = _schemas.OpSignature.from_function( # type: ignore[attr-defined]
overload_func,
domain,
name,
opset_version=opset_version,
)
onnx_decomposition = OnnxDecompMeta(
meta = OnnxDecompMeta(
onnx_function=overload_func,
fx_target=target,
signature=None,
is_custom=False,
is_complex=meta.is_complex,
is_complex=torchlib_meta.is_complex,
)
registry._register(target, onnx_decomposition)
registry._register(target, meta)
except Exception:
logger.exception("Failed to register '%s'. Skipped", qualified_name)
continue
# Gather ops from the internal torchlib registry
# TODO(justinchuby): Make this the main registry after torchlib is migrated to PyTorch
# Trigger registration
from torch.onnx._internal.exporter._torchlib import ops
del ops
for target, implementations in _torchlib_registry.registry.items(): # type: ignore[assignment]
for impl in implementations:
onnx_decomposition = OnnxDecompMeta(
onnx_function=impl,
fx_target=target, # type: ignore[arg-type]
)
registry._register(target, onnx_decomposition) # type: ignore[arg-type]
return registry
def _register(
@ -209,32 +223,23 @@ class ONNXRegistry:
function: The onnx-script function to register.
is_complex: Whether the function is a function that handles complex valued inputs.
"""
if not hasattr(function, "signature"):
try:
# TODO(justinchuby): Use the op_signature attribute when onnxscript is updated in CI
if isinstance(function, onnxscript.OnnxFunction):
function.signature = _schemas.OpSignature.from_function( # type: ignore[attr-defined]
function,
function.function_ir.domain,
function.name,
opset_version=function.opset.version,
)
else:
function.signature = _schemas.OpSignature.from_function( # type: ignore[attr-defined]
function, "__custom", function.__name__
)
except Exception:
logger.exception(
"Failed to infer the signature for function '%s'", function
)
if isinstance(target, torch._ops.OpOverloadPacket):
raise TypeError(
f"Target '{target}' should be provided as an OpOverload instead of an "
"OpOverloadPacket. You can get the default overload with "
"<op>.default"
)
onnx_decomposition = OnnxDecompMeta(
onnx_function=function,
fx_target=target,
is_custom=True,
is_complex=is_complex,
self._register(
target,
OnnxDecompMeta(
onnx_function=function,
fx_target=target,
signature=None,
is_custom=True,
is_complex=is_complex,
),
)
self._register(target, onnx_decomposition)
def get_decomps(self, target: TorchOp) -> list[OnnxDecompMeta]:
"""Returns a list of OnnxDecompMeta for the given op: torch.ops.<namespace>.<op_name>.<overload>.

View File

@ -5,37 +5,86 @@
from __future__ import annotations
__all__ = ["registry", "onnx_impl"]
__all__ = ["onnx_impl", "get_torchlib_ops"]
import collections
from typing import Callable, TypeVar
import logging
from typing import Any, Callable, Sequence, TypeVar
import onnxscript
import torch
from torch.onnx._internal.exporter import _constants, _registration
_T = TypeVar("_T", bound=Callable)
class Registry(collections.UserDict[Callable, list[Callable]]):
"""Registry for aten functions."""
def register(self, target: Callable, impl: Callable) -> None:
"""Register a function."""
self.data.setdefault(target, []).append(impl)
logger = logging.getLogger("__name__")
# Default registry
registry = Registry()
_registry: list[_registration.OnnxDecompMeta] = []
def onnx_impl(
target: Callable,
target: _registration.TorchOp | tuple[_registration.TorchOp, ...],
*,
trace_only: bool = False,
complex: bool = False,
no_compile: bool = False,
private: bool = False,
) -> Callable[[_T], _T]:
"""Register an ONNX implementation of a torch op."""
if isinstance(target, torch._ops.OpOverloadPacket):
raise TypeError(
f"Target '{target}' should be provided as an OpOverload instead of an "
"OpOverloadPacket. You can get the default overload with "
"<op>.default"
)
def wrapper(
func: _T,
) -> _T:
registry.register(target, func)
return func
processed_func: Any
if no_compile:
processed_func = func
else:
torchlib_opset = onnxscript.values.Opset(
domain=_constants.TORCHLIB_DOMAIN, version=1
)
if not trace_only:
# Compile the function
processed_func = onnxscript.script(opset=torchlib_opset)(func)
else:
processed_func = onnxscript.TracedOnnxFunction(torchlib_opset, func)
if not private:
# TODO(justinchuby): Simplify the logic and remove the private attribute
# Skip registration if private
if not isinstance(target, Sequence):
targets = (target,)
else:
targets = target # type: ignore[assignment]
for t in targets:
_registry.append(
_registration.OnnxDecompMeta(
onnx_function=processed_func,
fx_target=t,
signature=None,
is_complex=complex,
skip_signature_inference=no_compile,
)
)
return processed_func # type: ignore[return-value]
return wrapper
def get_torchlib_ops() -> tuple[_registration.OnnxDecompMeta, ...]:
# Trigger op registration
from torch.onnx._internal.exporter._torchlib import ops
del ops
assert len(_registry) != 0
return tuple(_registry)

View File

@ -60,7 +60,7 @@ def call_op(
return node.outputs
@onnx_impl(torch.ops.higher_order.cond)
@onnx_impl(torch.ops.higher_order.cond, no_compile=True)
def higher_order_cond(
cond: ir.Value,
true_func: ir.Function,