diff --git a/test/onnx/exporter/test_small_models_e2e.py b/test/onnx/exporter/test_small_models_e2e.py
index 9b90e2f8784..c8ebe386ed8 100644
--- a/test/onnx/exporter/test_small_models_e2e.py
+++ b/test/onnx/exporter/test_small_models_e2e.py
@@ -5,6 +5,7 @@ from __future__ import annotations
 
 import logging
 
+import pytest
 import transformers
 from onnxscript import ir
 
@@ -629,6 +630,35 @@ class DynamoExporterTest(common_utils.TestCase):
             [node.op_type for node in onnx_program.model.graph],
         )
 
+    def test_graph_attention_opset_23(self):
+        class Model(torch.nn.Module):
+            def forward(self, query, key, value):
+                return torch.nn.functional.scaled_dot_product_attention(
+                    query, key, value
+                )
+
+        query = torch.rand(32, 8, 128, 64, dtype=torch.float16)
+        key = torch.rand(32, 8, 128, 64, dtype=torch.float16)
+        value = torch.rand(32, 8, 128, 64, dtype=torch.float16)
+
+        onnx_program = self.export(Model(), (query, key, value), opset_version=23)
+        self.assertIn("Attention", [node.op_type for node in onnx_program.model.graph])
+
+    @pytest.mark.xfail(reason="Expected to fail until opset 23 is supported by ORT.")
+    def test_graph_accuracy_attention_opset_23(self):
+        class Model(torch.nn.Module):
+            def forward(self, query, key, value):
+                return torch.nn.functional.scaled_dot_product_attention(
+                    query, key, value
+                )
+
+        query = torch.rand(32, 8, 128, 64, dtype=torch.float16)
+        key = torch.rand(32, 8, 128, 64, dtype=torch.float16)
+        value = torch.rand(32, 8, 128, 64, dtype=torch.float16)
+
+        onnx_program = self.export(Model(), (query, key, value), opset_version=23)
+        onnx_testing.assert_onnx_program(onnx_program, atol=1e-3, rtol=1)
+
 
 if __name__ == "__main__":
     common_utils.run_tests()
diff --git a/torch/onnx/_internal/exporter/_torchlib/ops/nn.py b/torch/onnx/_internal/exporter/_torchlib/ops/nn.py
index 9ac2aca9c31..438be0d9c3b 100644
--- a/torch/onnx/_internal/exporter/_torchlib/ops/nn.py
+++ b/torch/onnx/_internal/exporter/_torchlib/ops/nn.py
@@ -5,9 +5,9 @@ from __future__ import annotations
 
-from typing import Optional
+from typing import Optional, TYPE_CHECKING
 
-from onnxscript.onnx_opset import opset20 as op20, opset21 as op21
+from onnxscript.onnx_opset import opset20 as op20, opset21 as op21, opset23 as op23
 
 import torch
 from torch.onnx._internal._lazy_import import onnxscript_ir as ir
@@ -15,8 +15,14 @@ from torch.onnx._internal.exporter._torchlib._tensor_typing import TFloat, TReal
 from torch.onnx._internal.exporter._torchlib._torchlib_registry import onnx_impl
 
 
+if TYPE_CHECKING:
+    from onnxscript.values import Opset
+
 aten = torch.ops.aten
 
+_INT64_MAX = 9223372036854775807
+_INT64_MIN = -9223372036854775808
+
 
 @onnx_impl(aten.gelu.default, trace_only=True, opset_introduced=20)
 def aten_gelu_opset20(
@@ -46,3 +52,221 @@ def aten_group_norm(
     return op21.GroupNormalization(
         input, weight, bias, epsilon=eps, num_groups=num_groups
     )
+
+
+@onnx_impl(
+    aten.scaled_dot_product_attention.default, trace_only=True, opset_introduced=23
+)
+def aten_scaled_dot_product_attention_23(
+    query: TFloat,
+    key: TFloat,
+    value: TFloat,
+    attn_mask: Optional[TFloat] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    enable_gqa: bool = False,
+) -> TFloat:
+    """scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> Tensor
+
+    Reference:
+    1. https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
+    2. https://onnx.ai/onnx/operators/onnx__Attention.html
+
+    Attempts to convert SDPA to the ONNX Attention op and falls back to an ONNX graph equivalent to the following PyTorch code::
+
+        scale_factor = 1 / math.sqrt(Q.size(-1)) if scale is None else scale
+        attn_mask = (
+            torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
+            if is_causal
+            else attn_mask
+        )
+        attn_mask = (
+            attn_mask.masked_fill(not attn_mask, -float("inf"))
+            if attn_mask.dtype == torch.bool
+            else attn_mask
+        )
+        attn_weight = torch.softmax(
+            (Q @ K.transpose(-2, -1) * scale_factor) + attn_mask, dim=-1
+        )
+        attn_weight = torch.dropout(attn_weight, dropout_p)
+        return attn_weight @ V
+
+    where Q, K, V are the query, key, and value tensors, respectively.
+    L is the target sequence length, S is the source sequence length, and E is the embedding size.
+    """
+    assert (not is_causal) or (is_causal and attn_mask is None), (
+        "is_causal and attn_mask cannot be set at the same time"
+    )
+
+    # The ONNX Attention op only covers non-training scenarios, so it is used only
+    # when dropout is disabled.
+    if dropout_p == 0:
+        if enable_gqa:
+            assert (
+                query.shape[1] > key.shape[1] == value.shape[1]
+                and query.shape[1] % key.shape[1] == 0
+            ), (
+                "SDPA (GQA or MQA) requires q_num_heads > kv_num_heads & q_num_heads % kv_num_heads == 0"
+            )
+        else:
+            assert query.shape[1] == key.shape[1] == value.shape[1], (
+                "SDPA (MHA) requires q_num_heads = kv_num_heads"
+            )
+        Y, _, _, _ = op23.Attention(
+            query,
+            key,
+            value,
+            attn_mask=attn_mask,
+            scale=scale,
+            # Inputs are 4D [batch, num_heads, seq_len, head_size], so the head
+            # counts come from dimension 1.
+            q_num_heads=query.shape[1],
+            kv_num_heads=key.shape[1],
+            is_causal=is_causal,
+        )
+        return Y
+
+    if scale is None:
+        scale = _attention_scale(query, op23)
+    scale = op23.CastLike(scale, query)
+
+    if is_causal:
+        attn_mask = _causal_attention_mask(query, key, op23)
+
+    if attn_mask is None:
+        return _aten_scaled_dot_product_attention_no_mask_onnx(
+            query, key, value, scale, dropout_p, op23
+        )
+
+    return _aten_scaled_dot_product_attention_float_mask_onnx(
+        query, key, value, attn_mask, scale, dropout_p, op23
+    )
+
+
+def _attention_scale(query: TFloat, op: Opset) -> TFloat:
+    """Calculate the scale factor for the attention result.
+
+    Args:
+        query: Tensor of shape [..., L, E]
+
+    Returns:
+        Scalar scale factor := 1 / math.sqrt(query.size(-1))
+    """
+    q_shape = op.Shape(query)
+    q_last_dim = op.Gather(q_shape, op.Constant(value_ints=[-1]))
+    embedding_size = op.CastLike(q_last_dim, query)
+    one = op.Constant(value_float=1.0)
+    cast_one = op.CastLike(one, query)
+    scale = op.Div(cast_one, op.Sqrt(embedding_size))
+    return scale
+
+
+def _causal_attention_mask(query: TFloat, key: TFloat, op: Opset) -> TFloat:
+    """Create a causal mask for the given query and key tensors.
+
+    Equivalent to::
+
+        mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
+        attn_mask = torch.zeros(L, S, dtype=torch.float)
+        attn_mask = attn_mask.masked_fill(not mask, -float("inf"))
+
+    Args:
+        query: Tensor of shape [..., L, E]
+        key: Tensor of shape [..., S, E]
+
+    Returns:
+        Tensor of shape [L, S]
+    """
+    q_shape = op.Shape(query)
+    k_shape = op.Shape(key)
+
+    target_length = op.Slice(
+        q_shape, op.Constant(value_ints=[-2]), op.Constant(value_ints=[-1])
+    )
+    source_length = op.Slice(
+        k_shape, op.Constant(value_ints=[-2]), op.Constant(value_ints=[-1])
+    )
+    # attn_mask = torch.ones(L, S) := {
+    size = op.Concat(target_length, source_length, axis=0)
+    attn_mask = op.Expand(op.Constant(value_float=1.0), size)
+    # }
+    attn_mask = op.Trilu(attn_mask, upper=0)
+    # The causal mask has 0s in the lower triangle and -inf in the upper triangle.
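+    # For example, with L = S = 3 the mask produced below is:
+    #     [[0., -inf, -inf],
+    #      [0.,   0., -inf],
+    #      [0.,   0.,   0.]]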
+    attn_mask = op.Where(
+        op.Equal(attn_mask, op.Constant(value_float=0.0)),
+        op.Constant(value_float=-float("inf")),
+        op.Constant(value_float=0.0),
+    )
+    attn_mask = op.CastLike(attn_mask, query)
+    return attn_mask
+
+
+def _aten_scaled_dot_product_attention_no_mask_onnx(
+    query: TFloat,
+    key: TFloat,
+    value: TFloat,
+    scale: TFloat,
+    dropout_p: float,
+    op: Opset,
+) -> TFloat:
+    # Swap the last two axes of key
+    key_last_dim = op.Shape(key, start=-1)
+    key_second_last_dim = op.Shape(key, start=-2, end=-1)
+    key_first_dims = op.Shape(key, end=-2)
+    # Contract the dimensions that are not the last two so we can transpose
+    # with a static permutation.
+    key_squeezed_shape = op.Concat(
+        op.Constant(value_ints=[-1]), key_second_last_dim, key_last_dim, axis=0
+    )
+    key_squeezed = op.Reshape(key, key_squeezed_shape)
+    key_squeezed_transposed = op.Transpose(key_squeezed, perm=[0, 2, 1])
+    key_transposed_shape = op.Concat(
+        key_first_dims, key_last_dim, key_second_last_dim, axis=0
+    )
+    key_transposed = op.Reshape(key_squeezed_transposed, key_transposed_shape)
+
+    # https://github.com/pytorch/pytorch/blob/12da0c70378b5be9135c6fda62a9863bce4a4818/aten/src/ATen/native/transformers/attention.cpp#L653
+    # Scale q, k before matmul for stability; see https://tinyurl.com/sudb9s96 for the math.
+    query_scaled = op.Mul(query, op.Sqrt(scale))
+    key_transposed_scaled = op.Mul(
+        key_transposed, op.CastLike(op.Sqrt(scale), key_transposed)
+    )
+    attn_weight = op.Softmax(
+        op.MatMul(query_scaled, key_transposed_scaled),
+        axis=-1,
+    )
+    attn_weight, _ = op.Dropout(attn_weight, dropout_p)
+    return op.MatMul(attn_weight, value)
+
+
+def _aten_scaled_dot_product_attention_float_mask_onnx(
+    query: TFloat,
+    key: TFloat,
+    value: TFloat,
+    attn_mask: TFloat,
+    scale: TFloat,
+    dropout_p: float,
+    op: Opset,
+) -> TFloat:
+    # Swap the last two axes of key
+    key_last_dim = op.Shape(key, start=-1)
+    key_second_last_dim = op.Shape(key, start=-2, end=-1)
+    key_first_dims = op.Shape(key, end=-2)
+    # Contract the dimensions that are not the last two so we can transpose
+    # with a static permutation.
+    key_squeezed_shape = op.Concat(
+        op.Constant(value_ints=[-1]), key_second_last_dim, key_last_dim, axis=0
+    )
+    key_squeezed = op.Reshape(key, key_squeezed_shape)
+    key_squeezed_transposed = op.Transpose(key_squeezed, perm=[0, 2, 1])
+    key_transposed_shape = op.Concat(
+        key_first_dims, key_last_dim, key_second_last_dim, axis=0
+    )
+    key_transposed = op.Reshape(key_squeezed_transposed, key_transposed_shape)
+
+    # https://github.com/pytorch/pytorch/blob/12da0c70378b5be9135c6fda62a9863bce4a4818/aten/src/ATen/native/transformers/attention.cpp#L653
+    # Scale q, k before matmul for stability; see https://tinyurl.com/sudb9s96 for the math.
+    query_scaled = op.Mul(query, op.Sqrt(scale))
+    key_transposed_scaled = op.Mul(key_transposed, op.Sqrt(scale))
+    attn_weight = op.Softmax(
+        op.Add(op.MatMul(query_scaled, key_transposed_scaled), attn_mask),
+        axis=-1,
+    )
+    attn_weight, _ = op.Dropout(attn_weight, dropout_p)
+    return op.MatMul(attn_weight, value)
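
For reference, a minimal usage sketch (mirroring the tests above) of how the opset 23 path can be exercised through the public torch.onnx.export API. The dynamo=True flag and opset_version argument are assumed to be available in the installed PyTorch build, and ONNX Runtime support for opset 23 is still pending, so only the graph structure is checked here:

    import torch

    class SDPA(torch.nn.Module):
        def forward(self, query, key, value):
            return torch.nn.functional.scaled_dot_product_attention(query, key, value)

    query = key = value = torch.rand(2, 8, 16, 64, dtype=torch.float16)
    # Export with opset 23 so SDPA can lower to a single ONNX Attention node.
    onnx_program = torch.onnx.export(
        SDPA(), (query, key, value), dynamo=True, opset_version=23
    )
    print([node.op_type for node in onnx_program.model.graph])  # expect "Attention"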