[ONNX] Allow exporter to export SDPA to Attention onnx operator (#154596)

Fixes [#149662](https://github.com/pytorch/pytorch/issues/149662)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154596
Approved by: https://github.com/justinchuby, https://github.com/titaiwangms

Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
Author: Kshitij Khode, 2025-06-04 14:29:44 +00:00; committed by PyTorch MergeBot
Parent: 31d12b3955
Commit: ca0c2985d3
2 changed files with 256 additions and 2 deletions
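
As a quick orientation for reviewers, a minimal usage sketch, assuming the dynamo-based entry point torch.onnx.export(..., dynamo=True) and the tensor shapes used in the new tests; the model and variable names are illustrative and not part of this commit:

import torch

class SDPAModel(torch.nn.Module):
    def forward(self, query, key, value):
        return torch.nn.functional.scaled_dot_product_attention(query, key, value)

query = torch.rand(32, 8, 128, 64, dtype=torch.float16)
key = torch.rand(32, 8, 128, 64, dtype=torch.float16)
value = torch.rand(32, 8, 128, 64, dtype=torch.float16)

# With opset_version=23 the exporter can emit a single Attention node for SDPA;
# at lower opsets it keeps the decomposed softmax(Q @ K^T * scale + mask) @ V graph.
onnx_program = torch.onnx.export(
    SDPAModel(), (query, key, value), dynamo=True, opset_version=23
)
print([node.op_type for node in onnx_program.model.graph])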


@@ -5,6 +5,7 @@ from __future__ import annotations
import logging
import pytest
import transformers
from onnxscript import ir
@@ -629,6 +630,35 @@ class DynamoExporterTest(common_utils.TestCase):
            [node.op_type for node in onnx_program.model.graph],
        )

    def test_graph_attention_opset_23(self):
        class Model(torch.nn.Module):
            def forward(self, query, key, value):
                return torch.nn.functional.scaled_dot_product_attention(
                    query, key, value
                )

        query = torch.rand(32, 8, 128, 64, dtype=torch.float16)
        key = torch.rand(32, 8, 128, 64, dtype=torch.float16)
        value = torch.rand(32, 8, 128, 64, dtype=torch.float16)

        onnx_program = self.export(Model(), (query, key, value), opset_version=23)
        self.assertIn("Attention", [node.op_type for node in onnx_program.model.graph])

    @pytest.mark.xfail(reason="Expected to fail until opset 23 is supported by ORT.")
    def test_graph_accuracy_attention_opset_23(self):
        class Model(torch.nn.Module):
            def forward(self, query, key, value):
                return torch.nn.functional.scaled_dot_product_attention(
                    query, key, value
                )

        query = torch.rand(32, 8, 128, 64, dtype=torch.float16)
        key = torch.rand(32, 8, 128, 64, dtype=torch.float16)
        value = torch.rand(32, 8, 128, 64, dtype=torch.float16)

        onnx_program = self.export(Model(), (query, key, value), opset_version=23)
        onnx_testing.assert_onnx_program(onnx_program, atol=1e-3, rtol=1)


if __name__ == "__main__":
    common_utils.run_tests()


@@ -5,9 +5,9 @@
from __future__ import annotations
from typing import Optional
from typing import Optional, TYPE_CHECKING
from onnxscript.onnx_opset import opset20 as op20, opset21 as op21
from onnxscript.onnx_opset import opset20 as op20, opset21 as op21, opset23 as op23
import torch
from torch.onnx._internal._lazy_import import onnxscript_ir as ir
@@ -15,8 +15,14 @@ from torch.onnx._internal.exporter._torchlib._tensor_typing import TFloat, TReal
from torch.onnx._internal.exporter._torchlib._torchlib_registry import onnx_impl

if TYPE_CHECKING:
    from onnxscript.values import Opset

aten = torch.ops.aten
_INT64_MAX = 9223372036854775807
_INT64_MIN = -9223372036854775808
@onnx_impl(aten.gelu.default, trace_only=True, opset_introduced=20)
def aten_gelu_opset20(
@@ -46,3 +52,221 @@ def aten_group_norm(
    return op21.GroupNormalization(
        input, weight, bias, epsilon=eps, num_groups=num_groups
    )

@onnx_impl(
    aten.scaled_dot_product_attention.default, trace_only=True, opset_introduced=23
)
def aten_scaled_dot_product_attention_23(
    query: TFloat,
    key: TFloat,
    value: TFloat,
    attn_mask: Optional[TFloat] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
) -> TFloat:
    """scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> Tensor

    Reference:
    1. https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
    2. https://onnx.ai/onnx/operators/onnx__Attention.html

    Attempts to convert SDPA to the Attention ONNX op and falls back to an ONNX graph
    equivalent to the following PyTorch code::

        scale_factor = 1 / math.sqrt(Q.size(-1)) if scale is None else scale
        attn_mask = (
            torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
            if is_causal
            else attn_mask
        )
        attn_mask = (
            attn_mask.masked_fill(not attn_mask, -float("inf"))
            if attn_mask.dtype == torch.bool
            else attn_mask
        )
        attn_weight = torch.softmax(
            (Q @ K.transpose(-2, -1) * scale_factor) + attn_mask, dim=-1
        )
        attn_weight = torch.dropout(attn_weight, dropout_p)
        return attn_weight @ V

    where Q, K, V are the query, key, and value tensors, respectively.
    L is the target sequence length, S is the source sequence length, and E is the embedding size.
    """
    assert (not is_causal) or (is_causal and attn_mask is None), (
        "is_causal and attn_mask cannot be set at the same time"
    )

    # The Attention ONNX op can only handle non-training scenarios where dropout is disabled.
    if dropout_p == 0:
        if enable_gqa:
            assert (
                query.shape[1] > key.shape[1] == value.shape[1]
                and query.shape[1] % key.shape[1] == 0
            ), (
                "SDPA (GQA or MQA) requires q_num_heads > kv_num_heads & q_num_heads % kv_num_heads == 0"
            )
        else:
            assert query.shape[1] == key.shape[1] == value.shape[1], (
                "SDPA (MHA) requires q_num_heads = kv_num_heads"
            )

        Y, _, _, _ = op23.Attention(
            query,
            key,
            value,
            attn_mask=attn_mask,
            scale=scale,
            q_num_heads=query.shape[3],
            kv_num_heads=key.shape[3],
            is_causal=is_causal,
        )
        return Y

    if scale is None:
        scale = _attention_scale(query, op23)
    scale = op23.CastLike(scale, query)

    if is_causal:
        attn_mask = _causal_attention_mask(query, key, op23)

    if attn_mask is None:
        return _aten_scaled_dot_product_attention_no_mask_onnx(
            query, key, value, scale, dropout_p, op23
        )

    return _aten_scaled_dot_product_attention_float_mask_onnx(
        query, key, value, attn_mask, scale, dropout_p, op23
    )

def _attention_scale(query: TFloat, op: Opset) -> TFloat:
    """Calculate the scale factor for the attention result.

    Args:
        query: Tensor of shape [..., L, E]

    Returns:
        Scalar scale factor := 1 / math.sqrt(query.size(-1))
    """
    q_shape = op.Shape(query)
    q_last_dim = op.Gather(q_shape, op.Constant(value_ints=[-1]))
    embedding_size = op.CastLike(q_last_dim, query)
    one = op.Constant(value_float=1.0)
    cast_one = op.CastLike(one, query)
    scale = op.Div(cast_one, op.Sqrt(embedding_size))
    return scale

def _causal_attention_mask(query: TFloat, key: TFloat, op: Opset) -> TFloat:
    """Create a causal mask for the given query and key tensors.

    Equivalent to::

        mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_mask = torch.zeros(L, S, dtype=torch.float)
        attn_mask = attn_mask.masked_fill(not mask, -float("inf"))

    Args:
        query: Tensor of shape [..., L, E]
        key: Tensor of shape [..., S, E]

    Returns:
        Tensor of shape [L, S]
    """
    q_shape = op.Shape(query)
    k_shape = op.Shape(key)
    target_length = op.Slice(
        q_shape, op.Constant(value_ints=[-2]), op.Constant(value_ints=[-1])
    )
    source_length = op.Slice(
        k_shape, op.Constant(value_ints=[-2]), op.Constant(value_ints=[-1])
    )
    # attn_mask = torch.ones(L, S) := {
    size = op.Concat(target_length, source_length, axis=0)
    attn_mask = op.Expand(op.Constant(value_float=1.0), size)
    # }
    attn_mask = op.Trilu(attn_mask, upper=0)
    # The causal mask has 0s in the lower triangle and -inf in the upper triangle.
    attn_mask = op.Where(
        op.Equal(attn_mask, op.Constant(value_float=0.0)),
        op.Constant(value_float=-float("inf")),
        op.Constant(value_float=0.0),
    )
    attn_mask = op.CastLike(attn_mask, query)
    return attn_mask

def _aten_scaled_dot_product_attention_no_mask_onnx(
    query: TFloat,
    key: TFloat,
    value: TFloat,
    scale: TFloat,
    dropout_p: float,
    op: Opset,
) -> TFloat:
    # Swap the last two axes of key.
    key_last_dim = op.Shape(key, start=-1)
    key_second_last_dim = op.Shape(key, start=-2, end=-1)
    key_first_dims = op.Shape(key, end=-2)
    # Contract the dimensions that are not the last two so we can transpose
    # with a static permutation.
    key_squeezed_shape = op.Concat(
        op.Constant(value_ints=[-1]), key_second_last_dim, key_last_dim, axis=0
    )
    key_squeezed = op.Reshape(key, key_squeezed_shape)
    key_squeezed_transposed = op.Transpose(key_squeezed, perm=[0, 2, 1])
    key_transposed_shape = op.Concat(
        key_first_dims, key_last_dim, key_second_last_dim, axis=0
    )
    key_transposed = op.Reshape(key_squeezed_transposed, key_transposed_shape)

    # https://github.com/pytorch/pytorch/blob/12da0c70378b5be9135c6fda62a9863bce4a4818/aten/src/ATen/native/transformers/attention.cpp#L653
    # Scale q and k before the matmul for numerical stability; see https://tinyurl.com/sudb9s96 for the math.
    query_scaled = op.Mul(query, op.Sqrt(scale))
    key_transposed_scaled = op.Mul(
        key_transposed, op.CastLike(op.Sqrt(scale), key_transposed)
    )
    attn_weight = op.Softmax(
        op.MatMul(query_scaled, key_transposed_scaled),
        axis=-1,
    )
    attn_weight, _ = op.Dropout(attn_weight, dropout_p)
    return op.MatMul(attn_weight, value)

def _aten_scaled_dot_product_attention_float_mask_onnx(
    query: TFloat,
    key: TFloat,
    value: TFloat,
    attn_mask: TFloat,
    scale: TFloat,
    dropout_p: float,
    op: Opset,
) -> TFloat:
    # Swap the last two axes of key.
    key_last_dim = op.Shape(key, start=-1)
    key_second_last_dim = op.Shape(key, start=-2, end=-1)
    key_first_dims = op.Shape(key, end=-2)
    # Contract the dimensions that are not the last two so we can transpose
    # with a static permutation.
    key_squeezed_shape = op.Concat(
        op.Constant(value_ints=[-1]), key_second_last_dim, key_last_dim, axis=0
    )
    key_squeezed = op.Reshape(key, key_squeezed_shape)
    key_squeezed_transposed = op.Transpose(key_squeezed, perm=[0, 2, 1])
    key_transposed_shape = op.Concat(
        key_first_dims, key_last_dim, key_second_last_dim, axis=0
    )
    key_transposed = op.Reshape(key_squeezed_transposed, key_transposed_shape)

    # https://github.com/pytorch/pytorch/blob/12da0c70378b5be9135c6fda62a9863bce4a4818/aten/src/ATen/native/transformers/attention.cpp#L653
    # Scale q and k before the matmul for numerical stability; see https://tinyurl.com/sudb9s96 for the math.
    query_scaled = op.Mul(query, op.Sqrt(scale))
    key_transposed_scaled = op.Mul(key_transposed, op.Sqrt(scale))
    attn_weight = op.Softmax(
        op.Add(op.MatMul(query_scaled, key_transposed_scaled), attn_mask),
        axis=-1,
    )
    attn_weight, _ = op.Dropout(attn_weight, dropout_p)
    return op.MatMul(attn_weight, value)
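
For reference, a small self-contained check, not part of this commit, that the decomposition described in the docstring above (softmax(Q @ K^T * scale_factor + mask) @ V) agrees with torch.nn.functional.scaled_dot_product_attention when dropout is disabled and no mask is used; shapes are arbitrary:

import math

import torch

q = torch.rand(2, 4, 16, 8)
k = torch.rand(2, 4, 16, 8)
v = torch.rand(2, 4, 16, 8)

# Decomposed fallback path (no mask, no dropout): softmax(Q @ K^T * scale) @ V.
scale_factor = 1 / math.sqrt(q.size(-1))
attn_weight = torch.softmax(q @ k.transpose(-2, -1) * scale_factor, dim=-1)
manual = attn_weight @ v

# Should match the native SDPA result up to floating-point tolerance.
reference = torch.nn.functional.scaled_dot_product_attention(q, k, v)
torch.testing.assert_close(manual, reference)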