# flake8: noqa: B950
import math
from typing import Callable, Optional, TypeVar

from typing_extensions import ParamSpec

import torch
from torch.onnx.ops import _dtype_mappings


# Use ParamSpec for better type preservation instead of a bound Callable TypeVar
_P = ParamSpec("_P")
_R = TypeVar("_R")

# ONNX to ATen decomp table
ONNX_ATEN_DECOMP_TABLE: dict[torch._ops.OpOverload, Callable] = {}

_ATTENTION_23_ALLOWED_INTERMEDIATE_PRECISIONS = frozenset(
    {
        1,  # FLOAT
        10,  # FLOAT16
        11,  # DOUBLE
        16,  # BFLOAT16
    }
)


def _onnx_op(
    op_type: str, opset_version: int
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
    """Decorator to register an ONNX operator with a custom implementation."""

    def decorator(func: Callable[_P, _R]) -> Callable[_P, _R]:
        overload = f"opset{opset_version}"
        torch_op = torch.library.custom_op(
            f"onnx::{op_type}.{overload}", mutates_args=()
        )(func)
        ONNX_ATEN_DECOMP_TABLE[getattr(getattr(torch.ops.onnx, op_type), overload)] = (
            func  # type: ignore[assignment]
        )
        # Use the same implementation for the fake implementation.
        # This is possible because we use pure aten ops to implement ONNX ops.
        torch_op.register_fake(func)
        return torch_op  # type: ignore[return-value]

    return decorator


@_onnx_op("RotaryEmbedding", 23)
def rotary_embedding_23(
    x: torch.Tensor,
    cos_cache: torch.Tensor,
    sin_cache: torch.Tensor,
    position_ids: Optional[torch.Tensor] = None,
    *,
    interleaved: bool = False,
    num_heads: int = 0,
    rotary_embedding_dim: int = 0,
) -> torch.Tensor:
    """RotaryEmbedding-23 https://onnx.ai/onnx/operators/onnx__RotaryEmbedding.html#rotaryembedding-23"""
    # x has shape (batch_size, num_heads, sequence_length, head_size)
    # or (batch_size, sequence_length, hidden_size)
    input_shape = x.shape
    input_rank = len(input_shape)
    batch_size = input_shape[0]
    sequence_length = input_shape[-2]

    # Validate that position_ids and the caches match x
    if position_ids is not None:
        torch._check(
            position_ids.dim() == 2,
            lambda: f"position_ids must be 2D when provided. Received shape {position_ids.shape}",
        )
        torch._check(
            position_ids.shape[0] == batch_size,
            lambda: f"position_ids first dim (batch) must match x.shape[0] ({batch_size}). Received {position_ids.shape[0]}",
        )
        torch._check(
            position_ids.shape[1] == sequence_length,
            lambda: f"position_ids second dim (sequence) must match x.shape[-2] ({sequence_length}). Received {position_ids.shape[1]}",
        )
        torch._check(
            cos_cache.dim() == 2 and sin_cache.dim() == 2,
            lambda: "cos_cache/sin_cache must be 2D when position_ids is provided. "
            f"Received cos_cache shape {cos_cache.shape}, sin_cache shape {sin_cache.shape}",
        )
    else:
        torch._check(
            cos_cache.dim() == 3 and sin_cache.dim() == 3,
            lambda: "cos_cache/sin_cache must be 3D when position_ids is not provided. "
            f"Received cos_cache shape {cos_cache.shape}, sin_cache shape {sin_cache.shape}",
        )

    # First ensure x has shape [batch_size, seq_len, num_heads, head_size]
    # so that the rotation logic can be shared with reshaped 3D inputs
    if input_rank == 4:
        # Permute from (batch_size, num_heads, seq_len, head_size)
        # to [batch_size, seq_len, num_heads, head_size]
        x = torch.permute(x, (0, 2, 1, 3))
    elif input_rank == 3:
        torch._check(
            num_heads != 0,
            lambda: f"num_heads must be provided for 3D inputs. Received input tensor with shape {input_shape}",
        )
        hidden_size = input_shape[2]
        head_size = hidden_size // num_heads
        new_shape = [batch_size, sequence_length, num_heads, head_size]
        x = torch.reshape(x, new_shape)
    torch._check(len(x.shape) == 4, lambda: "x should be a 4D tensor by now")
    head_size = x.shape[3]

    # Fully or partially rotate x based on the rotary_embedding_dim attribute
    if rotary_embedding_dim == 0:
        # If rotary_embedding_dim is not provided, perform a full rotation by using head_size
        rotary_embedding_dim = head_size

    x_rotate = x[:, :, :, :rotary_embedding_dim]
    x_not_rotate = x[:, :, :, rotary_embedding_dim:]
    rotary_embedding_dim_half = rotary_embedding_dim // 2

    # Retrieve sin and cos caches using position ids
    if position_ids is not None:
        cos = cos_cache[
            position_ids
        ]  # Shape: [batch_size, sequence_length, rotary_embedding_dim/2]
        sin = sin_cache[
            position_ids
        ]  # Shape: [batch_size, sequence_length, rotary_embedding_dim/2]
    else:
        cos = cos_cache  # Shape: [batch_size, sequence_length, rotary_embedding_dim/2]
        sin = sin_cache  # Shape: [batch_size, sequence_length, rotary_embedding_dim/2]

    torch._check(
        cos.shape[0] == batch_size and cos.shape[1] == sequence_length,
        lambda: f"cos has shape {cos.shape} but expected (batch={batch_size}, seq={sequence_length}, ...)",
    )
    torch._check(
        sin.shape[0] == batch_size and sin.shape[1] == sequence_length,
        lambda: f"sin has shape {sin.shape} but expected (batch={batch_size}, seq={sequence_length}, ...)",
    )
    torch._check(
        cos.shape[-1] == rotary_embedding_dim_half,
        lambda: f"Last dimension of cos cache ({cos.shape[-1]}) should match rotary_embedding_dim/2 ({rotary_embedding_dim_half}).",
    )
    torch._check(
        sin.shape[-1] == rotary_embedding_dim_half,
        lambda: f"Last dimension of sin cache ({sin.shape[-1]}) should match rotary_embedding_dim/2 ({rotary_embedding_dim_half}).",
    )

    cos = torch.unsqueeze(
        cos, 2
    )  # Shape: [batch_size, sequence_length, 1, rotary_embedding_dim/2]
    sin = torch.unsqueeze(
        sin, 2
    )  # Shape: [batch_size, sequence_length, 1, rotary_embedding_dim/2]

    # Either split x into halves or de-interleave it (based on the interleaved attribute)
    if interleaved:
        x1 = x_rotate[:, :, :, 0::2]
        x2 = x_rotate[:, :, :, 1::2]
    else:
        x1, x2 = torch.chunk(x_rotate, 2, dim=-1)

    # Calculate real and imaginary values
    real = cos * x1 - sin * x2
    imag = sin * x1 + cos * x2

    # Insert the rotated embeddings back into the original x
    if interleaved:
        # x_rotate[:, :, :, 0::2] = real
        # x_rotate[:, :, :, 1::2] = imag
        real = torch.unsqueeze(real, -1)
        imag = torch.unsqueeze(imag, -1)
        x_rotate_concat = torch.cat((real, imag), dim=-1)
        x_rotate = torch.reshape(x_rotate_concat, x_rotate.shape)
    else:
        x_rotate = torch.cat((real, imag), dim=-1)
    output = torch.cat((x_rotate, x_not_rotate), dim=-1)
    if input_rank == 3:
        return torch.reshape(output, input_shape)
    # Return the dimensions to the original order
    return torch.permute(output, (0, 2, 1, 3))


def _get_scale_factor(scale: Optional[float], head_size: int) -> float:
    """Get the scale factor for attention computation."""
    return scale if scale is not None else (1.0 / math.sqrt(head_size))


def _reshape_3d_to_4d(
    tensor: torch.Tensor, batch_size: int, num_heads: int
) -> torch.Tensor:
    """Reshape a 3D tensor to 4D for multi-head attention."""
    sequence_length, hidden_size = tensor.shape[1], tensor.shape[2]
    head_size = hidden_size // num_heads
    return (
        tensor.view(batch_size, sequence_length, num_heads, head_size)
        .transpose(1, 2)
        .contiguous()
    )


def _get_qk_output_for_aten_spda(
    Q: torch.Tensor,
    K: torch.Tensor,
    current_q_num_heads: int,
    current_kv_num_heads: int,
    scale: Optional[float],
    qk_matmul_output_mode: int,
) -> torch.Tensor:
    """Get the QK output tensor based on the specified mode."""
    if qk_matmul_output_mode == 0:
        return _compute_qk_output_for_mode_0(
            Q, K, current_q_num_heads, current_kv_num_heads, scale
        )
    else:
        # For other modes, return a zero tensor with the correct shape
        return torch.zeros_like(torch.matmul(Q, K.transpose(-2, -1)))


def _validate_gqa_configuration(
    current_q_num_heads: int, current_kv_num_heads: int
) -> None:
    """Validate the Group Query Attention configuration."""
    torch._check(
        current_q_num_heads % current_kv_num_heads == 0,
        lambda: f"q_num_heads ({current_q_num_heads}) must be divisible by kv_num_heads ({current_kv_num_heads}) for GQA",
    )


def _compute_qk_output_for_mode_0(
    Q: torch.Tensor,
    K: torch.Tensor,
    current_q_num_heads: int,
    current_kv_num_heads: int,
    scale: Optional[float],
) -> torch.Tensor:
    """Helper function to compute the QK output for qk_matmul_output_mode == 0."""
    # Handle GQA manually for the QK output
    K_for_qk = K
    if current_q_num_heads != current_kv_num_heads:
        repeat_factor = current_q_num_heads // current_kv_num_heads
        K_for_qk = K.repeat_interleave(repeat_factor, dim=1)
    scale_factor = _get_scale_factor(scale, Q.shape[3])
    # Scale both Q and K by sqrt(scale_factor) for numerical stability
    sqrt_scale = math.sqrt(scale_factor)
    Q_scaled = Q * sqrt_scale
    K_scaled = K_for_qk * sqrt_scale
    return torch.matmul(Q_scaled, K_scaled.transpose(-2, -1))


@_onnx_op("Attention", 23)
def attention_23(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    past_key: Optional[torch.Tensor] = None,
    past_value: Optional[torch.Tensor] = None,
    *,
    is_causal: bool = False,
    kv_num_heads: int = 0,
    q_num_heads: int = 0,
    qk_matmul_output_mode: int = 0,
    scale: Optional[float] = None,
    softcap: float = 0.0,
    softmax_precision: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Attention-23 https://onnx.ai/onnx/operators/onnx__Attention.html#attention-23"""
    num_head_dim, sequence_dim, head_dim = 1, 2, 3

    # Store the original input rank to determine the output shape
    input_shape_len = len(Q.shape)
    batch_size = Q.shape[0]

    # Reshape 3D inputs to the 4D format
    if len(Q.shape) == 3:
        torch._check(
            q_num_heads != 0 and kv_num_heads != 0,
            lambda: "q_num_heads and kv_num_heads must be provided for 3D inputs",
        )
        q_sequence_length = Q.shape[1]
        Q = _reshape_3d_to_4d(Q, batch_size, q_num_heads)
        K = _reshape_3d_to_4d(K, batch_size, kv_num_heads)
        V = _reshape_3d_to_4d(V, batch_size, kv_num_heads)

    torch._check(
        len(Q.shape) == 4 and len(K.shape) == 4 and len(V.shape) == 4,
        lambda: "Q, K, and V should be 4D tensors by now",
    )

    # Calculate the scale factor if it is not provided
    q_head_size = Q.shape[head_dim]
    scale = _get_scale_factor(scale, q_head_size)

    # Handle past key/value caches
    present_key = (
        torch.cat([past_key, K], dim=sequence_dim)
        if past_key is not None
        else K.clone()
    )
    present_value = (
        torch.cat([past_value, V], dim=sequence_dim)
        if past_value is not None
        else V.clone()
    )

    # Update K and V to include the past states
    K, V = present_key, present_value

    # Get the current dimensions
    current_q_num_heads = Q.shape[num_head_dim]
    current_kv_num_heads = K.shape[num_head_dim]
    q_sequence_length = Q.shape[sequence_dim]
    kv_sequence_length = K.shape[sequence_dim]

    # Check whether we can use the optimized scaled_dot_product_attention path
    can_use_sdpa = (
        softcap == 0.0  # No softcap
        and qk_matmul_output_mode == 0  # Default QK output mode
        and softmax_precision is None  # No custom softmax precision
        and (attn_mask is None or attn_mask.dtype == torch.bool)
    )

    _validate_gqa_configuration(current_q_num_heads, current_kv_num_heads)

    if can_use_sdpa:
        # Use PyTorch's optimized scaled_dot_product_attention.
        # Both ONNX Attention-23 and SDPA treat a boolean mask value of True as
        # "this position participates in attention", so the mask can be passed
        # through unchanged.
        sdpa_attn_mask = attn_mask

        output = torch.nn.functional.scaled_dot_product_attention(
            Q,
            K,
            V,
            attn_mask=sdpa_attn_mask,
            dropout_p=0.0,
            is_causal=is_causal,
            scale=scale,
            enable_gqa=bool(
                current_q_num_heads != current_kv_num_heads
            ),  # Ensure enable_gqa is not SymBool
        )

        qk_output = _get_qk_output_for_aten_spda(
            Q,
            K,
            current_q_num_heads,
            current_kv_num_heads,
            scale,
            qk_matmul_output_mode,
        )
    else:
        # Fall back to the manual implementation for the more complex cases

        # Handle Group Query Attention (GQA) and Multi-Query Attention (MQA)
        if current_q_num_heads != current_kv_num_heads:
            repeat_factor = current_q_num_heads // current_kv_num_heads
            K = K.repeat_interleave(repeat_factor, dim=num_head_dim)
            V = V.repeat_interleave(repeat_factor, dim=num_head_dim)

        # Create the attention bias
        attn_bias = torch.zeros(
            q_sequence_length, kv_sequence_length, dtype=Q.dtype, device=Q.device
        )

        # Apply causal masking
        if is_causal:
            torch._check(
                attn_mask is None, lambda: "Cannot use both is_causal and attn_mask"
            )
            causal_mask = torch.tril(
                torch.ones(
                    q_sequence_length,
                    kv_sequence_length,
                    dtype=torch.bool,
                    device=Q.device,
                )
            )
            attn_bias = attn_bias.masked_fill(~causal_mask, float("-inf"))

        # Apply the attention mask
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                # Boolean mask: True means the position participates in attention
                attn_bias = attn_bias.masked_fill(~attn_mask, float("-inf"))
            else:
                # Float mask: added to the attention scores
                attn_bias = attn_bias + attn_mask

        # Apply the scaling factor
        scale_factor = _get_scale_factor(scale, Q.shape[3])

        # Scale both Q and K by sqrt(scale_factor) for numerical stability
        sqrt_scale = math.sqrt(scale_factor)
        Q_scaled = Q * sqrt_scale
        K_scaled = K * sqrt_scale

        # Compute Q @ K^T
        qk_matmul_output = torch.matmul(Q_scaled, K_scaled.transpose(-2, -1))

        # Initialize the QK output based on the mode
        qk_output = qk_matmul_output  # Default case for mode 0

        # Add the attention bias
        qk_with_bias = qk_matmul_output + attn_bias

        if qk_matmul_output_mode == 1:
            qk_output = qk_with_bias

        # Apply softcap if provided
        if softcap > 0.0:
            qk_with_bias = softcap * torch.tanh(qk_with_bias / softcap)

        if qk_matmul_output_mode == 2:
            qk_output = qk_with_bias

        # Apply softmax with optional precision casting
        if softmax_precision is not None:
            # Map the ONNX data type to a torch dtype
            if softmax_precision in _ATTENTION_23_ALLOWED_INTERMEDIATE_PRECISIONS:
                original_dtype = qk_with_bias.dtype
                qk_with_bias = qk_with_bias.to(
                    _dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE[softmax_precision]
                )
                qk_softmax = torch.softmax(qk_with_bias, dim=-1)
                qk_softmax = qk_softmax.to(original_dtype)
            else:
                qk_softmax = torch.softmax(qk_with_bias, dim=-1)
        else:
            qk_softmax = torch.softmax(qk_with_bias, dim=-1)

        if qk_matmul_output_mode == 3:
            qk_output = qk_softmax

        # Compute the attention output
        output = torch.matmul(qk_softmax, V)

    # Reshape the output back to 3D if the input was 3D
    if input_shape_len == 3:
        # output: (batch_size, q_num_heads, q_sequence_length, v_head_size) -> (batch_size, q_sequence_length, hidden_size)
        output = (
            output.transpose(1, 2).contiguous().view(batch_size, q_sequence_length, -1)
        )

    return output, present_key, present_value, qk_output
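

# The block below is a minimal usage sketch, not part of the op registrations above:
# it is an illustrative addition with arbitrary example shapes, showing how the two
# registered ops can be exercised eagerly. The shapes satisfy the checks implemented
# above: 2D cos/sin caches plus 2D position_ids for RotaryEmbedding-23, and 4D
# Q/K/V (so no q_num_heads/kv_num_heads attributes are needed) for Attention-23.
if __name__ == "__main__":
    batch_size, num_heads, seq_len, head_size = 2, 4, 8, 16

    # RotaryEmbedding-23: the caches are indexed by position_ids, so their first
    # dimension must cover the largest position id used (here: seq_len positions).
    x = torch.randn(batch_size, num_heads, seq_len, head_size)
    position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, seq_len)
    cos_cache = torch.randn(seq_len, head_size // 2)
    sin_cache = torch.randn(seq_len, head_size // 2)
    rotated = rotary_embedding_23(x, cos_cache, sin_cache, position_ids)
    print(rotated.shape)  # torch.Size([2, 4, 8, 16])

    # Attention-23: causal self-attention without a mask takes the SDPA fast path.
    query = torch.randn(batch_size, num_heads, seq_len, head_size)
    key = torch.randn(batch_size, num_heads, seq_len, head_size)
    value = torch.randn(batch_size, num_heads, seq_len, head_size)
    output, present_key, present_value, qk_output = attention_23(
        query, key, value, is_causal=True
    )
    print(output.shape)  # torch.Size([2, 4, 8, 16])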