mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ONNX] Refactor dispatcher and registry (#147396)
This PR sets up the registry to accept onnx decomp functions to be moved into PyTorch (https://github.com/pytorch/pytorch/issues/139301). The ops from onnx script are currently appended to the registry. When the ops are moved into PyTorch, the moved ops take precedence because they appear first in the registry list. After the migration, the hooks for loading ops from onnx script will be removed. 1. Use a private field `_pt_onnx_signature` to store function signatures to avoid conflicts 2. Update the registry to record the signature in OnnxDecompMeta and update the dispatcher to leverage the data structure 3. Update the registry to prepare for onnx op registration, and update the onnx_impl decorator to support a no_compile option Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/147396 Approved by: https://github.com/titaiwangms
This commit is contained in:
parent
4f3c070b25
commit
279c7f262e
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import Callable
|
||||
from typing import Any, Callable
|
||||
|
||||
from onnxscript import ir
|
||||
|
||||
|
|
@ -188,11 +188,11 @@ def _get_type_from_tensor(
|
|||
|
||||
|
||||
def _get_first_tensor_in_node_list(
|
||||
nodes: Sequence[torch.fx.Node | None],
|
||||
nodes: Sequence[torch.fx.Node | Any],
|
||||
) -> torch.Tensor | None:
|
||||
for node in nodes:
|
||||
if (
|
||||
node is not None
|
||||
isinstance(node, torch.fx.Node)
|
||||
and "val" in node.meta
|
||||
and isinstance(node.meta["val"], torch.Tensor)
|
||||
):
|
||||
|
|
@ -213,13 +213,13 @@ def _get_named_fx_node_args(node: torch.fx.Node) -> dict[str, torch.fx.node.Argu
|
|||
|
||||
def get_matching_overload(
|
||||
node: torch.fx.Node,
|
||||
overloads: Sequence[Callable],
|
||||
overloads: Sequence[_registration.OnnxDecompMeta],
|
||||
) -> tuple[Callable | None, str]:
|
||||
"""Get the overload that matches the node's arguments.
|
||||
|
||||
Args:
|
||||
node: The node to match.
|
||||
overloads: The overloads to match against.
|
||||
overloads: The OnnxDecompMeta with overloads and their signatures to match against.
|
||||
|
||||
Returns:
|
||||
A tuple containing the matched overload and a string describing the reason for failure or success.
|
||||
|
|
@ -230,7 +230,7 @@ def get_matching_overload(
|
|||
# now we assume all inputs are named.
|
||||
return overloads[
|
||||
0
|
||||
], "The node target does not have a schema. Return the first one."
|
||||
].onnx_function, "The node target does not have a schema. Return the first one."
|
||||
named_args = _get_named_fx_node_args(node)
|
||||
# FIXME: Handle when we don't know the names of the arguments
|
||||
schema_args: dict[str, torch.Argument] = {
|
||||
|
|
@ -241,10 +241,10 @@ def get_matching_overload(
|
|||
for overload in overloads:
|
||||
assigned_types: dict[str, ir.TypeProtocol] = {}
|
||||
fail_reason = ""
|
||||
if not hasattr(overload, "signature"):
|
||||
if overload.signature is None:
|
||||
# When an overload does not have a signature, we assume it is a custom op and should be matched
|
||||
return (
|
||||
overload,
|
||||
overload.onnx_function,
|
||||
"The overload does not have a signature. Assuming it is a custom op and matching it.",
|
||||
)
|
||||
for param in overload.signature:
|
||||
|
|
@ -266,7 +266,7 @@ def get_matching_overload(
|
|||
arg = schema_args[param.name].default_value
|
||||
elif param.has_default():
|
||||
# Provided in the ONNX op definition
|
||||
arg = param.default
|
||||
arg = param.default # type: ignore[assignment]
|
||||
else:
|
||||
fail_reason = "Parameter not provided"
|
||||
break
|
||||
|
|
@ -297,8 +297,10 @@ def get_matching_overload(
|
|||
if not _attribute_type_compatible_with_arg(param, arg): # type: ignore[arg-type]
|
||||
fail_reason = f"Attribute type not compatible with argument: param=`{param}`, arg=`{arg}`"
|
||||
break
|
||||
else:
|
||||
raise TypeError(f"Unknown parameter type: {type(param)}")
|
||||
if not fail_reason:
|
||||
return overload, "Successfully matched overload"
|
||||
return overload.onnx_function, "Successfully matched overload"
|
||||
else:
|
||||
failure_messages.append(
|
||||
f"- Failed to match overload `{overload}`: {fail_reason}"
|
||||
|
|
@ -357,7 +359,5 @@ def dispatch(
|
|||
"Fast path: Only one decomposition is defined",
|
||||
)
|
||||
|
||||
overload, message = get_matching_overload(
|
||||
node, [decomp.onnx_function for decomp in decomp_metas]
|
||||
)
|
||||
overload, message = get_matching_overload(node, decomp_metas)
|
||||
return overload, message
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@ import logging
|
|||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from torch.onnx._internal._lazy_import import onnxscript_apis, onnxscript_ir as ir
|
||||
from torch.onnx._internal._lazy_import import onnxscript_ir as ir
|
||||
from torch.onnx._internal.exporter import _constants
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -115,8 +116,7 @@ def _maybe_set_opset_version(
|
|||
# Already set
|
||||
return
|
||||
if domain == _ONNX_DOMAIN:
|
||||
# Set the default opset version for ONNX operators
|
||||
opset_imports[domain] = onnxscript_apis.torchlib_opset_version()
|
||||
opset_imports[domain] = _constants.TORCHLIB_OPSET
|
||||
return
|
||||
if version is None:
|
||||
# We don't know the opset version, so set it to 1
|
||||
|
|
|
|||
|
|
@ -33,22 +33,59 @@ TorchOp: TypeAlias = Union[torch._ops.OpOverload, types.BuiltinFunctionType, Cal
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
@dataclasses.dataclass
class OnnxDecompMeta:
    """A wrapper of onnx-script function with additional metadata.

    The dataclass is intentionally not frozen: ``__post_init__`` assigns
    ``signature`` after construction when it has to be inferred.

    onnx_function: The onnx-script function from torchlib.
    fx_target: The PyTorch node callable target.
    signature: The ONNX signature of the function. When None, the signature is inferred.
    is_custom: Whether the function is a custom function.
    is_complex: Whether the function is a function that handles complex valued inputs.
    device: The device the function is registered to. If None, it is registered to all devices.
    skip_signature_inference: Whether to skip signature inference for the function.
    """

    onnx_function: Callable
    fx_target: TorchOp
    signature: _schemas.OpSignature | None
    is_custom: bool = False
    is_complex: bool = False
    device: Literal["cuda", "cpu"] | str | None = None  # noqa: PYI051
    skip_signature_inference: bool = False

    def __post_init__(self) -> None:
        """Infer the ONNX signature when none was supplied.

        Inference runs only if ``signature`` is None and
        ``skip_signature_inference`` is False. On success the signature is
        stored on this object and also on the wrapped function as
        ``_pt_onnx_signature``.

        Raises:
            Exception: Whatever signature inference raised, re-raised unchanged
                when the function is not custom. For custom functions the
                failure is logged as a warning and ``signature`` stays None.
        """
        if self.signature is None and not self.skip_signature_inference:
            try:
                if isinstance(self.onnx_function, onnxscript.OnnxFunction):
                    # Compiled onnx-script functions carry their own domain,
                    # name and opset version.
                    signature = _schemas.OpSignature.from_function(  # type: ignore[attr-defined]
                        self.onnx_function,
                        self.onnx_function.function_ir.domain,
                        self.onnx_function.name,
                        opset_version=self.onnx_function.opset.version,
                    )
                else:
                    # Any other callable is given the placeholder domain "__traced".
                    signature = _schemas.OpSignature.from_function(
                        self.onnx_function, "__traced", self.onnx_function.__name__
                    )
            except Exception as e:
                # Log a warning if the op is custom. Raise exception for builtin ops.
                if not self.is_custom:
                    raise
                else:
                    # When the function is targeting an HOP, for example, it will accept
                    # functions as arguments and fail to generate an ONNX signature.
                    # In this case we set signature to None and dispatch to this function always.
                    logger.warning(
                        "Failed to infer the signature for function '%s' because '%s'. "
                        "All nodes targeting `%s` will be dispatched to this function",
                        self.onnx_function,
                        e,
                        self.fx_target,
                    )
            else:
                self.signature = signature
                # Stored under a private name to avoid clashing with attributes
                # the function may already define.
                self.onnx_function._pt_onnx_signature = signature  # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _get_overload(qualified_name: str) -> torch._ops.OpOverload | None:
|
||||
|
|
@ -120,14 +157,15 @@ class ONNXRegistry:
|
|||
torchlib_registry: The torchlib registry to use for populating the registry.
|
||||
"""
|
||||
registry = cls()
|
||||
for meta in _torchlib_registry.get_torchlib_ops():
|
||||
registry._register(meta.fx_target, meta)
|
||||
|
||||
# TODO(justinchuby): Remove this once torchlib is migrated to PyTorch
|
||||
torchlib_ops = onnxscript_apis.get_torchlib_ops()
|
||||
|
||||
for meta in torchlib_ops:
|
||||
qualified_name = meta.qualified_name
|
||||
overload_func = meta.function
|
||||
domain = meta.domain
|
||||
name = meta.name
|
||||
for torchlib_meta in torchlib_ops:
|
||||
qualified_name = torchlib_meta.qualified_name
|
||||
overload_func = torchlib_meta.function
|
||||
try:
|
||||
# NOTE: This is heavily guarded with try-except because we don't want
|
||||
# to fail the entire registry population if one function fails.
|
||||
|
|
@ -135,42 +173,18 @@ class ONNXRegistry:
|
|||
if target is None:
|
||||
continue
|
||||
|
||||
if isinstance(overload_func, onnxscript.OnnxFunction):
|
||||
opset_version = overload_func.opset.version
|
||||
else:
|
||||
opset_version = 1
|
||||
|
||||
overload_func.signature = _schemas.OpSignature.from_function( # type: ignore[attr-defined]
|
||||
overload_func,
|
||||
domain,
|
||||
name,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
onnx_decomposition = OnnxDecompMeta(
|
||||
meta = OnnxDecompMeta(
|
||||
onnx_function=overload_func,
|
||||
fx_target=target,
|
||||
signature=None,
|
||||
is_custom=False,
|
||||
is_complex=meta.is_complex,
|
||||
is_complex=torchlib_meta.is_complex,
|
||||
)
|
||||
registry._register(target, onnx_decomposition)
|
||||
registry._register(target, meta)
|
||||
except Exception:
|
||||
logger.exception("Failed to register '%s'. Skipped", qualified_name)
|
||||
continue
|
||||
|
||||
# Gather ops from the internal torchlib registry
|
||||
# TODO(justinchuby): Make this the main registry after torchlib is migrated to PyTorch
|
||||
# Trigger registration
|
||||
from torch.onnx._internal.exporter._torchlib import ops
|
||||
|
||||
del ops
|
||||
for target, implementations in _torchlib_registry.registry.items(): # type: ignore[assignment]
|
||||
for impl in implementations:
|
||||
onnx_decomposition = OnnxDecompMeta(
|
||||
onnx_function=impl,
|
||||
fx_target=target, # type: ignore[arg-type]
|
||||
)
|
||||
registry._register(target, onnx_decomposition) # type: ignore[arg-type]
|
||||
|
||||
return registry
|
||||
|
||||
def _register(
|
||||
|
|
@ -209,32 +223,23 @@ class ONNXRegistry:
|
|||
function: The onnx-script function to register.
|
||||
is_complex: Whether the function is a function that handles complex valued inputs.
|
||||
"""
|
||||
if not hasattr(function, "signature"):
|
||||
try:
|
||||
# TODO(justinchuby): Use the op_signature attribute when onnxscript is updated in CI
|
||||
if isinstance(function, onnxscript.OnnxFunction):
|
||||
function.signature = _schemas.OpSignature.from_function( # type: ignore[attr-defined]
|
||||
function,
|
||||
function.function_ir.domain,
|
||||
function.name,
|
||||
opset_version=function.opset.version,
|
||||
)
|
||||
else:
|
||||
function.signature = _schemas.OpSignature.from_function( # type: ignore[attr-defined]
|
||||
function, "__custom", function.__name__
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to infer the signature for function '%s'", function
|
||||
)
|
||||
if isinstance(target, torch._ops.OpOverloadPacket):
|
||||
raise TypeError(
|
||||
f"Target '{target}' should be provided as an OpOverload instead of an "
|
||||
"OpOverloadPacket. You can get the default overload with "
|
||||
"<op>.default"
|
||||
)
|
||||
|
||||
onnx_decomposition = OnnxDecompMeta(
|
||||
onnx_function=function,
|
||||
fx_target=target,
|
||||
is_custom=True,
|
||||
is_complex=is_complex,
|
||||
self._register(
|
||||
target,
|
||||
OnnxDecompMeta(
|
||||
onnx_function=function,
|
||||
fx_target=target,
|
||||
signature=None,
|
||||
is_custom=True,
|
||||
is_complex=is_complex,
|
||||
),
|
||||
)
|
||||
self._register(target, onnx_decomposition)
|
||||
|
||||
def get_decomps(self, target: TorchOp) -> list[OnnxDecompMeta]:
|
||||
"""Returns a list of OnnxDecompMeta for the given op: torch.ops.<namespace>.<op_name>.<overload>.
|
||||
|
|
|
|||
|
|
@ -1 +0,0 @@
|
|||
|
||||
|
|
@ -5,37 +5,86 @@
|
|||
from __future__ import annotations
|
||||
|
||||
|
||||
__all__ = ["registry", "onnx_impl"]
|
||||
__all__ = ["onnx_impl", "get_torchlib_ops"]
|
||||
|
||||
import collections
|
||||
from typing import Callable, TypeVar
|
||||
import logging
|
||||
from typing import Any, Callable, Sequence, TypeVar
|
||||
|
||||
import onnxscript
|
||||
|
||||
import torch
|
||||
from torch.onnx._internal.exporter import _constants, _registration
|
||||
|
||||
|
||||
_T = TypeVar("_T", bound=Callable)
|
||||
|
||||
|
||||
class Registry(collections.UserDict[Callable, list[Callable]]):
|
||||
"""Registry for aten functions."""
|
||||
|
||||
def register(self, target: Callable, impl: Callable) -> None:
|
||||
"""Register a function."""
|
||||
|
||||
self.data.setdefault(target, []).append(impl)
|
||||
# Name the logger after this module (not the literal string "__name__") so its
# records are attributed to the module that emitted them.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Default registry
|
||||
registry = Registry()
|
||||
_registry: list[_registration.OnnxDecompMeta] = []
|
||||
|
||||
|
||||
def onnx_impl(
    target: _registration.TorchOp | tuple[_registration.TorchOp, ...],
    *,
    trace_only: bool = False,
    complex: bool = False,
    no_compile: bool = False,
    private: bool = False,
) -> Callable[[_T], _T]:
    """Register an ONNX implementation of a torch op.

    Args:
        target: The torch op, or a sequence of torch ops, that the decorated
            function implements.
        trace_only: When True the function is wrapped in
            ``onnxscript.TracedOnnxFunction``; otherwise it is compiled with
            ``onnxscript.script``. Ignored when ``no_compile`` is True.
        complex: Recorded as ``is_complex`` on the registered metadata.
        no_compile: When True the function is registered as-is, without any
            onnxscript processing, and signature inference is skipped.
        private: When True the function is processed but not added to the registry.

    Returns:
        A decorator that returns the processed function.

    Raises:
        TypeError: If any target is an ``OpOverloadPacket`` rather than an
            ``OpOverload``.
    """
    # Normalize once so validation and registration see the same targets.
    targets = tuple(target) if isinstance(target, Sequence) else (target,)

    # Check every target, not only a bare one: a packet inside a tuple would
    # otherwise be registered unchecked.
    for candidate in targets:
        if isinstance(candidate, torch._ops.OpOverloadPacket):
            raise TypeError(
                f"Target '{candidate}' should be provided as an OpOverload instead of an "
                "OpOverloadPacket. You can get the default overload with "
                "<op>.default"
            )

    def wrapper(
        func: _T,
    ) -> _T:
        processed_func: Any
        if no_compile:
            processed_func = func
        else:
            torchlib_opset = onnxscript.values.Opset(
                domain=_constants.TORCHLIB_DOMAIN, version=1
            )

            if not trace_only:
                # Compile the function
                processed_func = onnxscript.script(opset=torchlib_opset)(func)
            else:
                processed_func = onnxscript.TracedOnnxFunction(torchlib_opset, func)

        if not private:
            # TODO(justinchuby): Simplify the logic and remove the private attribute
            # Skip registration if private
            for t in targets:
                _registry.append(
                    _registration.OnnxDecompMeta(
                        onnx_function=processed_func,
                        fx_target=t,
                        signature=None,
                        is_complex=complex,
                        skip_signature_inference=no_compile,
                    )
                )
        return processed_func  # type: ignore[return-value]

    return wrapper
|
||||
|
||||
|
||||
def get_torchlib_ops() -> tuple[_registration.OnnxDecompMeta, ...]:
    """Return a snapshot of the decompositions recorded in ``_registry``.

    The ``ops`` module is imported first purely for its side effect of
    triggering op registration; the imported name is discarded straight away.

    Returns:
        An immutable tuple copy of the module-level ``_registry`` list.

    Raises:
        AssertionError: If nothing has been registered after the import.
    """
    # Imported only to trigger op registration (presumably through the
    # module's `@onnx_impl` decorators); the binding itself is not needed.
    from torch.onnx._internal.exporter._torchlib import ops

    del ops
    registered = tuple(_registry)
    assert len(registered) != 0
    return registered
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ def call_op(
|
|||
return node.outputs
|
||||
|
||||
|
||||
@onnx_impl(torch.ops.higher_order.cond)
|
||||
@onnx_impl(torch.ops.higher_order.cond, no_compile=True)
|
||||
def higher_order_cond(
|
||||
cond: ir.Value,
|
||||
true_func: ir.Function,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user