"""Implementation of symbolic FX ops to represent arbitrary ONNX ops. This module provides a way to create symbolic FX operators that can represent arbitrary ONNX operators. The operators are called "symbolic" because they don't do any actual computation but instead serve as placeholders in the computation graph. Each implementation contains two parts: A "real" implementation that produce all zeros based on the input shape and dtype, and a "fake" implementation that does more or less the same thing but is required by the `torch.library.custom_op` interface. """ # flake8: noqa: B950 import dataclasses from collections.abc import Sequence from typing import Optional, Union import torch from torch.onnx.ops import _dtype_mappings _INT_TYPE = "i" _FLOAT_TYPE = "f" _STRING_TYPE = "s" _INT_SEQ_TYPE = "is" _FLOAT_SEQ_TYPE = "fs" _STRING_SEQ_TYPE = "ss" @dataclasses.dataclass class EncodedAttrs: """Class to encode attributes from dictionary into lists of FX compatible attributes. Since FX does not support dictionaries, we need to encode the attributes into lists. This class provides a way to encode and decode the attributes. Attributes: attr_keys: List of attribute keys. attr_types: List of attribute types. Values can be "i" (int), "f" (float), "s" (string), "is" (int sequence), "fs" (float sequence), or "ss" (string sequence). attr_pos: List of tuples representing the start and end positions of each attribute in the corresponding list. attr_ints: List of integer attributes. attr_floats: List of float attributes. attr_strs: List of string attributes. """ attr_keys: list[str] attr_types: list[str] attr_pos: list[tuple[int, int]] attr_ints: list[int] attr_floats: list[float] attr_strs: list[str] @classmethod def from_dict( cls, attrs: dict[ str, Union[ int, float, str, bool, Sequence[int], Sequence[float], Sequence[str], Sequence[bool], ], ], ) -> "EncodedAttrs": encoded = cls( attr_keys=[], attr_types=[], attr_pos=[], attr_ints=[], attr_floats=[], attr_strs=[], ) for i, (k, v) in enumerate(attrs.items()): encoded.attr_keys.append(k) if isinstance(v, int): start_pos = len(encoded.attr_ints) encoded.attr_ints.append(v) encoded.attr_pos.append((start_pos, start_pos + 1)) encoded.attr_types.append(_INT_TYPE) elif isinstance(v, float): start_pos = len(encoded.attr_floats) encoded.attr_floats.append(v) encoded.attr_pos.append((start_pos, start_pos + 1)) encoded.attr_types.append(_FLOAT_TYPE) elif isinstance(v, str): start_pos = len(encoded.attr_strs) encoded.attr_strs.append(v) encoded.attr_pos.append((start_pos, start_pos + 1)) encoded.attr_types.append(_STRING_TYPE) elif isinstance(v, Sequence): if len(v) == 0: raise ValueError(f"Empty sequence for attribute {k}") if any(isinstance(elem, float) for elem in v): start_pos = len(encoded.attr_floats) encoded.attr_floats.extend([float(elem) for elem in v]) encoded.attr_pos.append((start_pos, start_pos + len(v))) encoded.attr_types.append(_FLOAT_SEQ_TYPE) elif isinstance(v[0], int): start_pos = len(encoded.attr_ints) encoded.attr_ints.extend([int(elem) for elem in v]) encoded.attr_pos.append((start_pos, start_pos + len(v))) encoded.attr_types.append(_INT_SEQ_TYPE) elif isinstance(v[0], str): start_pos = len(encoded.attr_strs) encoded.attr_strs.extend([str(elem) for elem in v]) encoded.attr_pos.append((start_pos, start_pos + len(v))) encoded.attr_types.append(_STRING_SEQ_TYPE) else: raise ValueError(f"Unsupported sequence type for attribute {k}") else: raise ValueError(f"Unsupported attribute type for {k}: {type(v)}") assert len(encoded.attr_keys) == 
@torch.library.custom_op(
    "onnx_symbolic::_symbolic",
    mutates_args=(),
    schema=(
        "(Tensor?[] inputs, str op_type, int onnx_dtype, *,"
        " SymInt[] shape, str[] attr_keys, str[] attr_types, int[][] attr_pos,"
        " int[] attr_ints, float[] attr_floats, str[] attr_strs, str[] metadata_props_keys,"
        " str[] metadata_props_values, str domain='', int? version=None"
        ") -> Tensor"
    ),
)
def _symbolic(
    inputs: Sequence[Optional[torch.Tensor]],
    op_type: str,
    onnx_dtype: int,
    *,
    shape: Sequence[Union[int, torch.SymInt]],
    attr_keys: Sequence[str],
    attr_types: Sequence[str],
    attr_pos: Sequence[tuple[int, int]],
    attr_ints: Sequence[int],
    attr_floats: Sequence[float],
    attr_strs: Sequence[str],
    metadata_props_keys: Sequence[str] = (),
    metadata_props_values: Sequence[str] = (),
    domain: str = "",
    version: Optional[int] = None,
) -> torch.Tensor:
    torch._check(
        onnx_dtype in _dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE,
        lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE.keys())}",
    )
    return torch.zeros(
        shape, dtype=_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype]
    )


@_symbolic.register_fake
def _(
    inputs: Sequence[Optional[torch.Tensor]],
    op_type: str,
    onnx_dtype: int,
    *,
    shape: Sequence[Union[int, torch.SymInt]],
    attr_keys: Sequence[str],
    attr_types: Sequence[str],
    attr_pos: Sequence[tuple[int, int]],
    attr_ints: Sequence[int],
    attr_floats: Sequence[float],
    attr_strs: Sequence[str],
    metadata_props_keys: Sequence[str] = (),
    metadata_props_values: Sequence[str] = (),
    domain: str = "",
    version: Optional[int] = None,
) -> torch.Tensor:
    torch._check(
        onnx_dtype in _dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE,
        lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE.keys())}",
    )
    # NOTE(justinchuby): Use zeros instead of torch.empty because I haven't figured
    # out how it can handle empty shapes
    return torch.zeros(
        shape, dtype=_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype]
    )
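

# Illustrative sketch (hypothetical helper, not part of the module's API):
# calling the single-output symbolic op directly. In practice the exporter
# inserts this op into an FX graph; the call below only demonstrates the
# calling convention. The ONNX data type code 1 corresponds to FLOAT.
def _example_symbolic_call() -> torch.Tensor:
    x = torch.randn(2, 3)
    attrs = EncodedAttrs.from_dict({"axis": -1})
    # Returns an all-zeros placeholder of the requested shape and dtype; no
    # Softmax is actually computed.
    return _symbolic(
        [x],
        "Softmax",
        1,
        shape=[2, 3],
        attr_keys=attrs.attr_keys,
        attr_types=attrs.attr_types,
        attr_pos=attrs.attr_pos,
        attr_ints=attrs.attr_ints,
        attr_floats=attrs.attr_floats,
        attr_strs=attrs.attr_strs,
        metadata_props_keys=[],
        metadata_props_values=[],
    )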
@torch.library.custom_op(
    "onnx_symbolic::_symbolic_multi_out",
    mutates_args=(),
    schema=(
        "(Tensor?[] inputs, str op_type, int[] onnx_dtypes, *,"
        " SymInt[][] shapes, str[] attr_keys, str[] attr_types, int[][] attr_pos,"
        " int[] attr_ints, float[] attr_floats, str[] attr_strs, str[] metadata_props_keys,"
        " str[] metadata_props_values, str domain='', int? version=None"
        ") -> Tensor[]"
    ),
)
def _symbolic_multi_out(
    inputs: Sequence[Optional[torch.Tensor]],
    op_type: str,
    onnx_dtypes: Sequence[int],
    *,
    shapes: Sequence[Sequence[Union[int, torch.SymInt]]],
    attr_keys: Sequence[str],
    attr_types: Sequence[str],
    attr_pos: Sequence[tuple[int, int]],
    attr_ints: Sequence[int],
    attr_floats: Sequence[float],
    attr_strs: Sequence[str],
    metadata_props_keys: Sequence[str] = (),
    metadata_props_values: Sequence[str] = (),
    domain: str = "",
    version: Optional[int] = None,
) -> list[torch.Tensor]:
    outputs = []
    torch._check(
        len(shapes) == len(onnx_dtypes),
        lambda: f"Number of shapes ({len(shapes)}) must match number of ONNX dtypes ({len(onnx_dtypes)})",
    )
    for shape, onnx_dtype in zip(shapes, onnx_dtypes):
        torch._check(
            onnx_dtype in _dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE,
            lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE.keys())}",
        )
        outputs.append(
            torch.zeros(
                shape, dtype=_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype]
            )
        )
    return outputs


@_symbolic_multi_out.register_fake
def _(
    inputs: Sequence[Optional[torch.Tensor]],
    op_type: str,
    onnx_dtypes: Sequence[int],
    *,
    shapes: Sequence[Sequence[Union[int, torch.SymInt]]],
    attr_keys: Sequence[str],
    attr_types: Sequence[str],
    attr_pos: Sequence[tuple[int, int]],
    attr_ints: Sequence[int],
    attr_floats: Sequence[float],
    attr_strs: Sequence[str],
    metadata_props_keys: Sequence[str] = (),
    metadata_props_values: Sequence[str] = (),
    domain: str = "",
    version: Optional[int] = None,
) -> list[torch.Tensor]:
    outputs = []
    torch._check(
        len(shapes) == len(onnx_dtypes),
        lambda: f"Number of shapes ({len(shapes)}) must match number of ONNX dtypes ({len(onnx_dtypes)})",
    )
    for shape, onnx_dtype in zip(shapes, onnx_dtypes):
        torch._check(
            onnx_dtype in _dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE,
            lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE.keys())}",
        )
        # NOTE(justinchuby): Use zeros instead of torch.empty because I haven't figured
        # out how it can handle empty shapes
        outputs.append(
            torch.zeros(
                shape, dtype=_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype]
            )
        )
    return outputs
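

# Illustrative sketch (hypothetical helper, not part of the module's API):
# the multi-output variant takes one shape and one ONNX dtype per output. The
# op type and domain below are made up; the ONNX data type codes 1 and 7
# correspond to FLOAT and INT64 respectively.
def _example_symbolic_multi_out_call() -> list[torch.Tensor]:
    x = torch.randn(4, 8)
    attrs = EncodedAttrs.from_dict({"axis": -1})
    # Two placeholder outputs: a float tensor and an int64 tensor, both zeros.
    return _symbolic_multi_out(
        [x],
        "MyTwoOutputOp",
        [1, 7],
        shapes=[[4, 8], [4]],
        attr_keys=attrs.attr_keys,
        attr_types=attrs.attr_types,
        attr_pos=attrs.attr_pos,
        attr_ints=attrs.attr_ints,
        attr_floats=attrs.attr_floats,
        attr_strs=attrs.attr_strs,
        metadata_props_keys=[],
        metadata_props_values=[],
        domain="custom.domain",
    )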