diff --git a/docs/source/onnx.rst b/docs/source/onnx.rst
index 4f53e570466..db569fe7a1c 100644
--- a/docs/source/onnx.rst
+++ b/docs/source/onnx.rst
@@ -88,6 +88,7 @@ also be interested in reading our `development wiki
diff --git a/torch/onnx/_internal/exporter/_core.py b/torch/onnx/_internal/exporter/_core.py
--- a/torch/onnx/_internal/exporter/_core.py
+++ b/torch/onnx/_internal/exporter/_core.py
 ) -> None:
-    # TODO: Consider using meta["tensor_meta"] for this? Would it be faster?
     if isinstance(meta_val, tuple):
         logger.warning("Setting shape and type of tensors is not supported yet")
     if isinstance(meta_val, torch.Tensor):
-        # FIXME: Consider shape for complex values
         dims = []
         for dim in meta_val.shape:
             if isinstance(dim, int):
                 dims.append(dim)
             else:
                 dims.append(str(dim.node))
-        value.dtype = _torch_dtype_to_onnx_dtype(meta_val.dtype)
-        if complex_to_float:
-            if meta_val.dtype == torch.complex64:
-                value.dtype = ir.DataType.FLOAT
-                # Add 2 as the last dimension if the tensor is complex to hold the real/imag parts
-                dims.append(2)
-            elif meta_val.dtype == torch.complex128:
-                value.dtype = ir.DataType.DOUBLE
-                # Add 2 as the last dimension if the tensor is complex to hold the real/imag parts
-                dims.append(2)
+
+        # If the dtype is set already (e.g. by the onnx_symbolic ops),
+        # we don't need to set it again.
+        #
+        # When a user specifies complex in onnx_symbolic, we consider that to
+        # be the intention even though none of the ONNX ops deal with complex values.
+        # In this case, we don't change the dtype or the shape of the tensor.
+        if value.dtype is None:
+            value.dtype = _torch_dtype_to_onnx_dtype(meta_val.dtype)
+            if complex_to_float:
+                if meta_val.dtype == torch.complex64:
+                    value.dtype = ir.DataType.FLOAT
+                    # Add 2 as the last dimension if the tensor is complex to hold the real/imag parts
+                    dims.append(2)
+                elif meta_val.dtype == torch.complex128:
+                    value.dtype = ir.DataType.DOUBLE
+                    # Add 2 as the last dimension if the tensor is complex to hold the real/imag parts
+                    dims.append(2)
         value.shape = ir.Shape(dims)
     elif isinstance(meta_val, (int, torch.SymInt)):
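For intuition, the default complex lowering in the hunk above matches `torch.view_as_real`: a complex64 tensor of shape [3] becomes a FLOAT tensor of shape [3, 2] that stores real/imag pairs in the last dimension. A minimal sketch (illustration only, not part of the patch):

    import torch

    t = torch.randn(3, dtype=torch.complex64)
    as_real = torch.view_as_real(t)  # real/imag pairs in a trailing dimension of size 2
    assert as_real.dtype == torch.float32
    assert as_real.shape == (3, 2)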
diff --git a/torch/onnx/_internal/exporter/_torchlib/ops/__init__.py b/torch/onnx/_internal/exporter/_torchlib/ops/__init__.py
index 88f569708bf..d07768f252b 100644
--- a/torch/onnx/_internal/exporter/_torchlib/ops/__init__.py
+++ b/torch/onnx/_internal/exporter/_torchlib/ops/__init__.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-__all__ = ["core", "hop"]
+__all__ = ["core", "hop", "symbolic"]
 
-from torch.onnx._internal.exporter._torchlib.ops import core, hop
+from torch.onnx._internal.exporter._torchlib.ops import core, hop, symbolic
 
diff --git a/torch/onnx/_internal/exporter/_torchlib/ops/symbolic.py b/torch/onnx/_internal/exporter/_torchlib/ops/symbolic.py
new file mode 100644
index 00000000000..3a30d30cc4f
--- /dev/null
+++ b/torch/onnx/_internal/exporter/_torchlib/ops/symbolic.py
@@ -0,0 +1,149 @@
+"""Implementation of the ONNX symbolic ops (torch.ops.onnx_symbolic) for the exporter."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from onnxscript.ir import convenience as ir_convenience
+
+import torch
+from torch.onnx._internal._lazy_import import onnxscript_ir as ir
+from torch.onnx._internal.exporter import _core
+from torch.onnx._internal.exporter._torchlib._torchlib_registry import onnx_impl
+from torch.onnx.ops import _symbolic_impl
+
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+
+def _call_symbolic_op(
+    op_type: str,
+    domain: str,
+    args: Sequence[ir.Value | None],
+    kwargs: dict[str, int | float | str | bool | list[int] | list[float] | list[str]],
+    dtypes: Sequence[int],
+    version: int | None,
+    metadata_props: dict[str, str] | None,
+) -> Sequence[ir.Value]:
+    """Call an operator with the given arguments and keyword arguments.
+
+    Arguments are always inputs, while keyword arguments are attributes.
+    """
+    # This is a wrapper around the IR node creation that hooks into the _builder.OpRecorder
+    # tracer so that all nodes created are recorded the same way as if we were to use
+    # onnxscript ops directly.
+
+    assert _core.current_tracer is not None
+    tracer = _core.current_tracer
+
+    inputs = list(args)
+
+    # If final inputs are None, strip them from the node inputs
+    for input in reversed(inputs):
+        if input is not None:
+            break
+        inputs.pop()
+
+    # Construct and filter out None attributes
+    attributes = [
+        attr
+        for attr in ir_convenience.convert_attributes(kwargs)  # type: ignore[arg-type]
+        if attr.value is not None  # type: ignore[union-attr]
+    ]
+    tracer.nodes.append(
+        node := ir.Node(
+            domain,
+            op_type,
+            inputs=inputs,
+            attributes=attributes,
+            num_outputs=len(dtypes),
+            version=version,
+            metadata_props=metadata_props,
+        )
+    )
+    # Set the dtypes for the outputs. We set them here because the graph builder
+    # uses PyTorch types, which are sometimes inaccurate for ONNX-only types
+    # like FLOAT4E2M1.
+    for value, dtype in zip(node.outputs, dtypes):
+        value.dtype = ir.DataType(dtype)
+    # The shape is set by the graph builder. We don't need to set it here.
+    return node.outputs
+
+
+@onnx_impl(torch.ops.onnx_symbolic._symbolic.default, no_compile=True)
+def onnx_symbolic_symbolic(
+    inputs: Sequence[ir.Value | None],
+    op_type: str,
+    onnx_dtype: int,
+    *,
+    shape: Sequence[int | ir.Value],
+    attr_keys: Sequence[str],
+    attr_types: Sequence[str],
+    attr_pos: Sequence[tuple[int, int]],
+    attr_ints: Sequence[int],
+    attr_floats: Sequence[float],
+    attr_strs: Sequence[str],
+    metadata_props_keys: Sequence[str] = (),
+    metadata_props_values: Sequence[str] = (),
+    domain: str = "",
+    version: int | None = None,
+) -> ir.Value:
+    del shape  # Unused. The shapes are set by the graph builder
+    encoded = _symbolic_impl.EncodedAttrs(
+        attr_keys=list(attr_keys),
+        attr_types=list(attr_types),
+        attr_pos=list(attr_pos),
+        attr_ints=list(attr_ints),
+        attr_floats=list(attr_floats),
+        attr_strs=list(attr_strs),
+    )
+    attrs = encoded.to_dict()
+    return _call_symbolic_op(
+        op_type,
+        domain,
+        inputs,
+        attrs,
+        dtypes=[onnx_dtype],
+        version=version,
+        metadata_props=dict(zip(metadata_props_keys, metadata_props_values)),
+    )[0]
+
+
+@onnx_impl(torch.ops.onnx_symbolic._symbolic_multi_out.default, no_compile=True)
+def onnx_symbolic_symbolic_multi_out(
+    inputs: Sequence[ir.Value | None],
+    op_type: str,
+    onnx_dtypes: Sequence[int],
+    *,
+    shapes: Sequence[Sequence[int | ir.Value]],
+    attr_keys: Sequence[str],
+    attr_types: Sequence[str],
+    attr_pos: Sequence[tuple[int, int]],
+    attr_ints: Sequence[int],
+    attr_floats: Sequence[float],
+    attr_strs: Sequence[str],
+    metadata_props_keys: Sequence[str] = (),
+    metadata_props_values: Sequence[str] = (),
+    domain: str = "",
+    version: int | None = None,
+) -> Sequence[ir.Value]:
+    del shapes  # Unused. The shapes are set by the graph builder
+    encoded = _symbolic_impl.EncodedAttrs(
+        attr_keys=list(attr_keys),
+        attr_types=list(attr_types),
+        attr_pos=list(attr_pos),
+        attr_ints=list(attr_ints),
+        attr_floats=list(attr_floats),
+        attr_strs=list(attr_strs),
+    )
+    attrs = encoded.to_dict()
+    return _call_symbolic_op(
+        op_type,
+        domain,
+        inputs,
+        attrs,
+        dtypes=onnx_dtypes,
+        version=version,
+        metadata_props=dict(zip(metadata_props_keys, metadata_props_values)),
+    )
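The trailing-None stripping in `_call_symbolic_op` above keeps interior Nones (optional ONNX inputs) but drops unused trailing ones. It behaves like this standalone helper (a minimal sketch, not part of the patch):

    def trim_trailing_nones(items: list) -> list:
        """Drop None entries from the end, keeping interior Nones."""
        items = list(items)
        while items and items[-1] is None:
            items.pop()
        return items

    assert trim_trailing_nones([1, None, 2, None, None]) == [1, None, 2]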
diff --git a/torch/onnx/ops/__init__.py b/torch/onnx/ops/__init__.py
new file mode 100644
index 00000000000..e22a03d6c83
--- /dev/null
+++ b/torch/onnx/ops/__init__.py
@@ -0,0 +1,243 @@
+"""ONNX operators as native torch.fx operators.
+
+This module provides a set of functions to create ONNX operators in the FX graph
+which are exportable to ONNX.
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+from torch.onnx.ops import _symbolic_impl
+
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+
+# https://github.com/onnx/onnx/blob/f542e1f06699ea7e1db5f62af53355b64338c723/onnx/onnx.proto#L597
+_TORCH_DTYPE_TO_ONNX_DTYPE = {
+    torch.float32: 1,  # FLOAT
+    torch.uint8: 2,  # UINT8
+    torch.int8: 3,  # INT8
+    torch.uint16: 4,  # UINT16
+    torch.int16: 5,  # INT16
+    torch.int32: 6,  # INT32
+    torch.int64: 7,  # INT64
+    str: 8,  # STRING
+    torch.bool: 9,  # BOOL
+    torch.float16: 10,  # FLOAT16
+    torch.double: 11,  # DOUBLE
+    torch.uint32: 12,  # UINT32
+    torch.uint64: 13,  # UINT64
+    torch.complex64: 14,  # COMPLEX64
+    torch.complex128: 15,  # COMPLEX128
+    torch.bfloat16: 16,  # BFLOAT16
+    torch.float8_e4m3fn: 17,  # FLOAT8E4M3FN
+    torch.float8_e4m3fnuz: 18,  # FLOAT8E4M3FNUZ
+    torch.float8_e5m2: 19,  # FLOAT8E5M2
+    torch.float8_e5m2fnuz: 20,  # FLOAT8E5M2FNUZ
+}
+
+
+def _parse_domain_op_type(domain_op: str) -> tuple[str, str]:
+    splitted = domain_op.split("::", 1)
+    if len(splitted) == 1:
+        domain = ""
+        op_type = splitted[0]
+    else:
+        domain = splitted[0]
+        op_type = splitted[1]
+    return domain, op_type
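+
+
+# For example (illustrative):
+#   _parse_domain_op_type("custom_domain::CustomOp") == ("custom_domain", "CustomOp")
+#   _parse_domain_op_type("CustomOp") == ("", "CustomOp")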
+
+
+def symbolic(
+    domain_op: str,
+    /,
+    inputs: Sequence[torch.Tensor],
+    attrs: dict[
+        str,
+        int
+        | float
+        | str
+        | bool
+        | Sequence[int]
+        | Sequence[float]
+        | Sequence[str]
+        | Sequence[bool],
+    ]
+    | None = None,
+    *,
+    dtype: torch.dtype | int,
+    shape: Sequence[int | torch.SymInt],
+    version: int | None = None,
+    metadata_props: dict[str, str] | None = None,
+) -> torch.Tensor:
+    """Create a symbolic FX operator to represent an arbitrary ONNX operator.
+
+    This function is used to create a symbolic operator with a single output.
+    To create an operator with multiple outputs, use :func:`symbolic_multi_out`.
+
+    Example::
+
+        class CustomOp(torch.nn.Module):
+            def forward(self, x: torch.Tensor):
+                return torch.onnx.ops.symbolic(
+                    "custom_domain::CustomOp",
+                    (x,),
+                    dict(attr_key="attr_value"),
+                    dtype=x.dtype,
+                    shape=x.shape,
+                    version=1,
+                )
+
+        # This will create a symbolic ONNX operator with the name "CustomOp" in
+        # the "custom_domain" domain. The output tensor will have the specified
+        # dtype and shape.
+
+        # You may then export this model to ONNX using torch.onnx.export.
+
+    Args:
+        domain_op: The domain and operator name, separated by "::". For example,
+            "custom_domain::CustomOp".
+        inputs: The input tensors to the operator.
+        attrs: The attributes of the operator. The keys are attribute names and
+            the values are attribute values. Valid attribute types are int, float,
+            str, bool, and lists of int, float, str, and bool. Tensor attributes
+            are unsupported.
+        dtype: The data type of the output tensor. This can be either a torch.dtype
+            or an integer representing the ONNX data type.
+        shape: The shape of the output tensor. This can be a list of integers or
+            SymInt values.
+        version: The version of the opset used for the operator.
+        metadata_props: Metadata properties for the ONNX node.
+            This is a dictionary of str-str pairs.
+
+    Returns:
+        The output tensor of the operator.
+    """
+    if not isinstance(dtype, int):
+        torch._check(
+            dtype in _TORCH_DTYPE_TO_ONNX_DTYPE, lambda: f"Unsupported dtype: {dtype}"
+        )
+        dtype = _TORCH_DTYPE_TO_ONNX_DTYPE[dtype]
+    domain, op_type = _parse_domain_op_type(domain_op)
+    if attrs is None:
+        attrs = {}
+    encoded_attrs = _symbolic_impl.EncodedAttrs.from_dict(attrs)
+    return _symbolic_impl._symbolic(
+        inputs,
+        op_type,
+        dtype,
+        shape=shape,
+        attr_keys=encoded_attrs.attr_keys,
+        attr_types=encoded_attrs.attr_types,
+        attr_pos=encoded_attrs.attr_pos,
+        attr_ints=encoded_attrs.attr_ints,
+        attr_floats=encoded_attrs.attr_floats,
+        attr_strs=encoded_attrs.attr_strs,
+        metadata_props_keys=metadata_props.keys() if metadata_props else [],
+        metadata_props_values=metadata_props.values() if metadata_props else [],
+        domain=domain,
+        version=version,
+    )
+
+
+def symbolic_multi_out(
+    domain_op: str,
+    /,
+    inputs: Sequence[torch.Tensor],
+    attrs: dict[
+        str,
+        int
+        | float
+        | str
+        | bool
+        | Sequence[int]
+        | Sequence[float]
+        | Sequence[str]
+        | Sequence[bool],
+    ]
+    | None = None,
+    *,
+    dtypes: Sequence[torch.dtype | int],
+    shapes: Sequence[Sequence[int | torch.SymInt]],
+    version: int | None = None,
+    metadata_props: dict[str, str] | None = None,
+) -> Sequence[torch.Tensor]:
+    """Create a symbolic FX operator to represent an arbitrary ONNX operator with multiple outputs.
+
+    Example::
+
+        class CustomOp(torch.nn.Module):
+            def forward(self, x: torch.Tensor):
+                return torch.onnx.ops.symbolic_multi_out(
+                    "custom_domain::CustomOp",
+                    (x,),
+                    dict(attr_key="attr_value"),
+                    dtypes=(x.dtype, torch.float32),
+                    shapes=(x.shape, [1, 2, 3]),
+                    version=1,
+                )
+
+        # This will create a symbolic ONNX operator with the name "CustomOp" in
+        # the "custom_domain" domain. The output tensors will have the specified
+        # dtypes and shapes.
+
+        # You may then export this model to ONNX using torch.onnx.export.
+
+    Args:
+        domain_op: The domain and operator name, separated by "::". For example,
+            "custom_domain::CustomOp".
+        inputs: The input tensors to the operator.
+        attrs: The attributes of the operator. The keys are attribute names and
+            the values are attribute values. Valid attribute types are int, float,
+            str, bool, and lists of int, float, str, and bool. Tensor attributes
+            are unsupported.
+        dtypes: The data types of the output tensors. This can be a list of
+            torch.dtype or integers representing the ONNX data types. The length
+            of this list must be the number of outputs.
+        shapes: The shapes of the output tensors. This can be a list of lists of
+            integers or SymInt values. The length of this list must be the number of outputs.
+        version: The version of the opset used for the operator.
+        metadata_props: Metadata properties for the ONNX node.
+            This is a dictionary of str-str pairs.
+
+    Returns:
+        A list of output tensors of the operator.
+    """
+    torch._check(
+        len(shapes) == len(dtypes),
+        lambda: f"Number of shapes ({len(shapes)}) must match number of dtypes ({len(dtypes)})",
+    )
+    onnx_dtypes = []
+    for dtype in dtypes:
+        if not isinstance(dtype, int):
+            torch._check(
+                dtype in _TORCH_DTYPE_TO_ONNX_DTYPE,
+                lambda: f"Unsupported dtype: {dtype}",
+            )
+            onnx_dtypes.append(_TORCH_DTYPE_TO_ONNX_DTYPE[dtype])
+        else:
+            onnx_dtypes.append(dtype)
+    domain, op_type = _parse_domain_op_type(domain_op)
+    if attrs is None:
+        attrs = {}
+    encoded_attrs = _symbolic_impl.EncodedAttrs.from_dict(attrs)
+    # Use the size of dtypes to determine the number of outputs
+    return _symbolic_impl._symbolic_multi_out(
+        inputs,
+        op_type,
+        onnx_dtypes,
+        shapes=shapes,
+        attr_keys=encoded_attrs.attr_keys,
+        attr_types=encoded_attrs.attr_types,
+        attr_pos=encoded_attrs.attr_pos,
+        attr_ints=encoded_attrs.attr_ints,
+        attr_floats=encoded_attrs.attr_floats,
+        attr_strs=encoded_attrs.attr_strs,
+        metadata_props_keys=metadata_props.keys() if metadata_props else [],
+        metadata_props_values=metadata_props.values() if metadata_props else [],
+        domain=domain,
+        version=version,
+    )
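For orientation before the implementation file below: a minimal sketch of how `torch.onnx.ops.symbolic` surfaces in a traced graph (the module and input values here are hypothetical):

    import torch

    class M(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.onnx.ops.symbolic(
                "custom_domain::CustomOp",
                (x,),
                {"attr_key": "attr_value"},
                dtype=x.dtype,
                shape=x.shape,
                version=1,
            )

    ep = torch.export.export(M(), (torch.randn(2, 3),))
    # The exported graph contains a call to torch.ops.onnx_symbolic._symbolic.default,
    # which the ONNX exporter then lowers to a custom_domain::CustomOp node.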
+ """ + torch._check( + len(shapes) == len(dtypes), + lambda: f"Number of shapes ({len(shapes)}) must match number of dtypes ({len(dtypes)})", + ) + onnx_dtypes = [] + for dtype in dtypes: + if not isinstance(dtype, int): + torch._check( + dtype in _TORCH_DTYPE_TO_ONNX_DTYPE, + lambda: f"Unsupported dtype: {dtype}", + ) + onnx_dtypes.append(_TORCH_DTYPE_TO_ONNX_DTYPE[dtype]) + else: + onnx_dtypes.append(dtype) + domain, op_type = _parse_domain_op_type(domain_op) + if attrs is None: + attrs = {} + encoded_attrs = _symbolic_impl.EncodedAttrs.from_dict(attrs) + # Use the size of dtypes to determine the number of outputs + return _symbolic_impl._symbolic_multi_out( + inputs, + op_type, + onnx_dtypes, + shapes=shapes, + attr_keys=encoded_attrs.attr_keys, + attr_types=encoded_attrs.attr_types, + attr_pos=encoded_attrs.attr_pos, + attr_ints=encoded_attrs.attr_ints, + attr_floats=encoded_attrs.attr_floats, + attr_strs=encoded_attrs.attr_strs, + metadata_props_keys=metadata_props.keys() if metadata_props else [], + metadata_props_values=metadata_props.values() if metadata_props else [], + domain=domain, + version=version, + ) diff --git a/torch/onnx/ops/_symbolic_impl.py b/torch/onnx/ops/_symbolic_impl.py new file mode 100644 index 00000000000..7dd1370720a --- /dev/null +++ b/torch/onnx/ops/_symbolic_impl.py @@ -0,0 +1,330 @@ +"""Implementation of symbolic FX ops to represent arbitrary ONNX ops. + +This module provides a way to create symbolic FX operators that can represent +arbitrary ONNX operators. + +The operators are called "symbolic" because they don't do any actual computation +but instead serve as placeholders in the computation graph. + +Each implementation contains two parts: A "real" implementation that produce all +zeros based on the input shape and dtype, and a "fake" implementation that does more +or less the same thing but is required by the `torch.library.custom_op` interface. +""" + +import dataclasses +from collections.abc import Sequence +from typing import Optional, Union + +import torch + + +_ONNX_DTYPE_TO_TORCH_DTYPE: dict[int, torch.dtype] = { + 1: torch.float32, # FLOAT + 2: torch.uint8, # UINT8 + 3: torch.int8, # INT8 + 4: torch.uint16, # UINT16 + 5: torch.int16, # INT16 + 6: torch.int32, # INT32 + 7: torch.int64, # INT64 + 9: torch.bool, # BOOL + 10: torch.float16, # FLOAT16 + 11: torch.double, # DOUBLE + 12: torch.uint32, # UINT32 + 13: torch.uint64, # UINT64 + 14: torch.complex64, # COMPLEX64 + 15: torch.complex128, # COMPLEX128 + 16: torch.bfloat16, # BFLOAT16 + 17: torch.float8_e4m3fn, # FLOAT8E4M3FN + 18: torch.float8_e4m3fnuz, # FLOAT8E4M3FNUZ + 19: torch.float8_e5m2, # FLOAT8E5M2 + 20: torch.float8_e5m2fnuz, # FLOAT8E5M2FNUZ + 21: torch.uint8, # UINT4 + 22: torch.uint8, # INT4 + 23: torch.uint8, # FLOAT4E2M1 +} + +_INT_TYPE = "i" +_FLOAT_TYPE = "f" +_STRING_TYPE = "s" +_INT_SEQ_TYPE = "is" +_FLOAT_SEQ_TYPE = "fs" +_STRING_SEQ_TYPE = "ss" + + +@dataclasses.dataclass +class EncodedAttrs: + """Class to encode attributes from dictionary into lists of FX compatible attributes. + + Since FX does not support dictionaries, we need to encode the attributes into + lists. This class provides a way to encode and decode the attributes. + + Attributes: + attr_keys: List of attribute keys. + attr_types: List of attribute types. Values can be "i" (int), "f" (float), + "s" (string), "is" (int sequence), "fs" (float sequence), or "ss" (string sequence). + attr_pos: List of tuples representing the start and end positions of each + attribute in the corresponding list. 
+    """
+
+    attr_keys: list[str]
+    attr_types: list[str]
+    attr_pos: list[tuple[int, int]]
+    attr_ints: list[int]
+    attr_floats: list[float]
+    attr_strs: list[str]
+
+    @classmethod
+    def from_dict(
+        cls,
+        attrs: dict[
+            str,
+            Union[
+                int,
+                float,
+                str,
+                bool,
+                Sequence[int],
+                Sequence[float],
+                Sequence[str],
+                Sequence[bool],
+            ],
+        ],
+    ) -> "EncodedAttrs":
+        encoded = cls(
+            attr_keys=[],
+            attr_types=[],
+            attr_pos=[],
+            attr_ints=[],
+            attr_floats=[],
+            attr_strs=[],
+        )
+        for k, v in attrs.items():
+            encoded.attr_keys.append(k)
+            if isinstance(v, int):
+                start_pos = len(encoded.attr_ints)
+                encoded.attr_ints.append(v)
+                encoded.attr_pos.append((start_pos, start_pos + 1))
+                encoded.attr_types.append(_INT_TYPE)
+            elif isinstance(v, float):
+                start_pos = len(encoded.attr_floats)
+                encoded.attr_floats.append(v)
+                encoded.attr_pos.append((start_pos, start_pos + 1))
+                encoded.attr_types.append(_FLOAT_TYPE)
+            elif isinstance(v, str):
+                start_pos = len(encoded.attr_strs)
+                encoded.attr_strs.append(v)
+                encoded.attr_pos.append((start_pos, start_pos + 1))
+                encoded.attr_types.append(_STRING_TYPE)
+            elif isinstance(v, Sequence):
+                if len(v) == 0:
+                    raise ValueError(f"Empty sequence for attribute {k}")
+                if any(isinstance(elem, float) for elem in v):
+                    start_pos = len(encoded.attr_floats)
+                    encoded.attr_floats.extend([float(elem) for elem in v])
+                    encoded.attr_pos.append((start_pos, start_pos + len(v)))
+                    encoded.attr_types.append(_FLOAT_SEQ_TYPE)
+                elif isinstance(v[0], int):
+                    start_pos = len(encoded.attr_ints)
+                    encoded.attr_ints.extend([int(elem) for elem in v])
+                    encoded.attr_pos.append((start_pos, start_pos + len(v)))
+                    encoded.attr_types.append(_INT_SEQ_TYPE)
+                elif isinstance(v[0], str):
+                    start_pos = len(encoded.attr_strs)
+                    encoded.attr_strs.extend([str(elem) for elem in v])
+                    encoded.attr_pos.append((start_pos, start_pos + len(v)))
+                    encoded.attr_types.append(_STRING_SEQ_TYPE)
+                else:
+                    raise ValueError(f"Unsupported sequence type for attribute {k}")
+            else:
+                raise ValueError(f"Unsupported attribute type for {k}: {type(v)}")
+        assert len(encoded.attr_keys) == len(encoded.attr_types), (
+            f"Mismatch between number of attribute keys and types: {len(encoded.attr_keys)} != {len(encoded.attr_types)}"
+        )
+        assert len(encoded.attr_keys) == len(encoded.attr_pos), (
+            f"Mismatch between number of attribute keys and positions: {len(encoded.attr_keys)} != {len(encoded.attr_pos)}"
+        )
+        return encoded
+
+    def to_dict(
+        self,
+    ) -> dict[
+        str,
+        Union[
+            int,
+            float,
+            str,
+            list[int],
+            list[float],
+            list[str],
+        ],
+    ]:
+        """Convert the encoded attributes back to a dictionary for creating an ONNX node."""
+        attrs: dict[
+            str,
+            Union[
+                int,
+                float,
+                str,
+                list[int],
+                list[float],
+                list[str],
+            ],
+        ] = {}
+        for i, key in enumerate(self.attr_keys):
+            attr_type = self.attr_types[i]
+            if attr_type == _INT_TYPE:
+                attrs[key] = self.attr_ints[self.attr_pos[i][0]]
+            elif attr_type == _FLOAT_TYPE:
+                attrs[key] = self.attr_floats[self.attr_pos[i][0]]
+            elif attr_type == _STRING_TYPE:
+                attrs[key] = self.attr_strs[self.attr_pos[i][0]]
+            elif attr_type == _FLOAT_SEQ_TYPE:
+                attrs[key] = self.attr_floats[self.attr_pos[i][0] : self.attr_pos[i][1]]
+            elif attr_type == _INT_SEQ_TYPE:
+                attrs[key] = self.attr_ints[self.attr_pos[i][0] : self.attr_pos[i][1]]
+            elif attr_type == _STRING_SEQ_TYPE:
+                attrs[key] = self.attr_strs[self.attr_pos[i][0] : self.attr_pos[i][1]]
+            else:
+                raise ValueError(f"Unsupported attribute type: {attr_type}")
+        return attrs
+
+
+@torch.library.custom_op(
+    "onnx_symbolic::_symbolic",
+    mutates_args=(),
+    schema=(
+        "(Tensor?[] inputs, str op_type, int onnx_dtype, *,"
+        " SymInt[] shape, str[] attr_keys, str[] attr_types, int[][] attr_pos,"
+        " int[] attr_ints, float[] attr_floats, str[] attr_strs, str[] metadata_props_keys,"
+        " str[] metadata_props_values, str domain='', int? version=None"
+        ") -> Tensor"
+    ),
+)
+def _symbolic(
+    inputs: Sequence[Optional[torch.Tensor]],
+    op_type: str,
+    onnx_dtype: int,
+    *,
+    shape: Sequence[Union[int, torch.SymInt]],
+    attr_keys: Sequence[str],
+    attr_types: Sequence[str],
+    attr_pos: Sequence[tuple[int, int]],
+    attr_ints: Sequence[int],
+    attr_floats: Sequence[float],
+    attr_strs: Sequence[str],
+    metadata_props_keys: Sequence[str] = (),
+    metadata_props_values: Sequence[str] = (),
+    domain: str = "",
+    version: Optional[int] = None,
+) -> torch.Tensor:
+    torch._check(
+        onnx_dtype in _ONNX_DTYPE_TO_TORCH_DTYPE,
+        lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_ONNX_DTYPE_TO_TORCH_DTYPE.keys())}",
+    )
+    return torch.zeros(shape, dtype=_ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype])
+
+
+@_symbolic.register_fake
+def _(
+    inputs: Sequence[torch.Tensor],
+    op_type: str,
+    onnx_dtype: int,
+    *,
+    shape: Sequence[Union[int, torch.SymInt]],
+    attr_keys: Sequence[str],
+    attr_types: Sequence[str],
+    attr_pos: Sequence[tuple[int, int]],
+    attr_ints: Sequence[int],
+    attr_floats: Sequence[float],
+    attr_strs: Sequence[str],
+    metadata_props_keys: Sequence[str] = (),
+    metadata_props_values: Sequence[str] = (),
+    domain: str = "",
+    version: Optional[int] = None,
+) -> torch.Tensor:
+    torch._check(
+        onnx_dtype in _ONNX_DTYPE_TO_TORCH_DTYPE,
+        lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_ONNX_DTYPE_TO_TORCH_DTYPE.keys())}",
+    )
+    # NOTE(justinchuby): Use zeros instead of torch.empty because I haven't figured
+    # out how it can handle empty shapes
+    return torch.zeros(shape, dtype=_ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype])
+
+
+@torch.library.custom_op(
+    "onnx_symbolic::_symbolic_multi_out",
+    mutates_args=(),
+    schema=(
+        "(Tensor?[] inputs, str op_type, int[] onnx_dtypes, *,"
+        " SymInt[][] shapes, str[] attr_keys, str[] attr_types, int[][] attr_pos,"
+        " int[] attr_ints, float[] attr_floats, str[] attr_strs, str[] metadata_props_keys,"
+        " str[] metadata_props_values, str domain='', int? version=None"
+        ") -> Tensor[]"
+    ),
+)
+def _symbolic_multi_out(
+    inputs: Sequence[Optional[torch.Tensor]],
+    op_type: str,
+    onnx_dtypes: Sequence[int],
+    *,
+    shapes: Sequence[Sequence[Union[int, torch.SymInt]]],
+    attr_keys: Sequence[str],
+    attr_types: Sequence[str],
+    attr_pos: Sequence[tuple[int, int]],
+    attr_ints: Sequence[int],
+    attr_floats: Sequence[float],
+    attr_strs: Sequence[str],
+    metadata_props_keys: Sequence[str] = (),
+    metadata_props_values: Sequence[str] = (),
+    domain: str = "",
+    version: Optional[int] = None,
+) -> list[torch.Tensor]:
+    outputs = []
+    torch._check(
+        len(shapes) == len(onnx_dtypes),
+        lambda: f"Number of shapes ({len(shapes)}) must match number of ONNX dtypes ({len(onnx_dtypes)})",
+    )
+    for shape, onnx_dtype in zip(shapes, onnx_dtypes):
+        torch._check(
+            onnx_dtype in _ONNX_DTYPE_TO_TORCH_DTYPE,
+            lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_ONNX_DTYPE_TO_TORCH_DTYPE.keys())}",
+        )
+        outputs.append(torch.zeros(shape, dtype=_ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype]))
+    return outputs
+
+
+@_symbolic_multi_out.register_fake
+def _(
+    inputs: Sequence[torch.Tensor],
+    op_type: str,
+    onnx_dtypes: Sequence[int],
+    *,
+    shapes: Sequence[Sequence[Union[int, torch.SymInt]]],
+    attr_keys: Sequence[str],
+    attr_types: Sequence[str],
+    attr_pos: Sequence[tuple[int, int]],
+    attr_ints: Sequence[int],
+    attr_floats: Sequence[float],
+    attr_strs: Sequence[str],
+    metadata_props_keys: Sequence[str] = (),
+    metadata_props_values: Sequence[str] = (),
+    domain: str = "",
+    version: Optional[int] = None,
+) -> list[torch.Tensor]:
+    outputs = []
+    torch._check(
+        len(shapes) == len(onnx_dtypes),
+        lambda: f"Number of shapes ({len(shapes)}) must match number of ONNX dtypes ({len(onnx_dtypes)})",
+    )
+    for shape, onnx_dtype in zip(shapes, onnx_dtypes):
+        torch._check(
+            onnx_dtype in _ONNX_DTYPE_TO_TORCH_DTYPE,
+            lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_ONNX_DTYPE_TO_TORCH_DTYPE.keys())}",
+        )
+        # NOTE(justinchuby): Use zeros instead of torch.empty because I haven't figured
+        # out how it can handle empty shapes
+        outputs.append(torch.zeros(shape, dtype=_ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype]))
+    return outputs
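Taken together, a minimal end-to-end sketch (assuming a PyTorch build where `torch.onnx.ops` and the dynamo-based exporter are available):

    import torch

    class CustomOpModel(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.onnx.ops.symbolic(
                "custom_domain::CustomOp",
                (x,),
                dtype=x.dtype,
                shape=x.shape,
                version=1,
            )

    onnx_program = torch.onnx.export(
        CustomOpModel(), (torch.randn(2, 3),), dynamo=True
    )
    # The resulting ONNX graph should contain a single custom_domain::CustomOp
    # node whose output is FLOAT with shape [2, 3], matching the declared
    # dtype/shape rather than anything computed by the placeholder op.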