[ONNX] Allow exporter to export SDPA to Attention onnx operator (#154596)
Fixes [#149662](https://github.com/pytorch/pytorch/issues/149662)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154596
Approved by: https://github.com/justinchuby, https://github.com/titaiwangms
Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
This commit is contained in:
parent 31d12b3955
commit ca0c2985d3
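With this change, exporting a model that calls torch.nn.functional.scaled_dot_product_attention while targeting opset 23 lowers SDPA to the ONNX Attention operator instead of a decomposed subgraph. A minimal usage sketch, not part of this commit, assuming a PyTorch build that includes it and the dynamo-based torch.onnx.export path:

import torch


class SDPAModel(torch.nn.Module):
    def forward(self, query, key, value):
        return torch.nn.functional.scaled_dot_product_attention(query, key, value)


query = torch.rand(32, 8, 128, 64, dtype=torch.float16)
key = torch.rand(32, 8, 128, 64, dtype=torch.float16)
value = torch.rand(32, 8, 128, 64, dtype=torch.float16)

# Export with the dynamo-based exporter, targeting opset 23.
onnx_program = torch.onnx.export(
    SDPAModel(), (query, key, value), dynamo=True, opset_version=23
)

# The graph should now contain an Attention node (see the new test below).
print([node.op_type for node in onnx_program.model.graph])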
@@ -5,6 +5,7 @@ from __future__ import annotations
import logging

import pytest
import transformers
from onnxscript import ir

@@ -629,6 +630,35 @@ class DynamoExporterTest(common_utils.TestCase):
            [node.op_type for node in onnx_program.model.graph],
        )

    def test_graph_attention_opset_23(self):
        class Model(torch.nn.Module):
            def forward(self, query, key, value):
                return torch.nn.functional.scaled_dot_product_attention(
                    query, key, value
                )

        query = torch.rand(32, 8, 128, 64, dtype=torch.float16)
        key = torch.rand(32, 8, 128, 64, dtype=torch.float16)
        value = torch.rand(32, 8, 128, 64, dtype=torch.float16)

        onnx_program = self.export(Model(), (query, key, value), opset_version=23)
        self.assertIn("Attention", [node.op_type for node in onnx_program.model.graph])

    @pytest.mark.xfail(reason="Expected to fail until opset 23 is supported by ORT.")
    def test_graph_accuracy_attention_opset_23(self):
        class Model(torch.nn.Module):
            def forward(self, query, key, value):
                return torch.nn.functional.scaled_dot_product_attention(
                    query, key, value
                )

        query = torch.rand(32, 8, 128, 64, dtype=torch.float16)
        key = torch.rand(32, 8, 128, 64, dtype=torch.float16)
        value = torch.rand(32, 8, 128, 64, dtype=torch.float16)

        onnx_program = self.export(Model(), (query, key, value), opset_version=23)
        onnx_testing.assert_onnx_program(onnx_program, atol=1e-3, rtol=1)


if __name__ == "__main__":
    common_utils.run_tests()

@@ -5,9 +5,9 @@
from __future__ import annotations

from typing import Optional
from typing import Optional, TYPE_CHECKING

from onnxscript.onnx_opset import opset20 as op20, opset21 as op21
from onnxscript.onnx_opset import opset20 as op20, opset21 as op21, opset23 as op23

import torch
from torch.onnx._internal._lazy_import import onnxscript_ir as ir

@@ -15,8 +15,14 @@ from torch.onnx._internal.exporter._torchlib._tensor_typing import TFloat, TReal
from torch.onnx._internal.exporter._torchlib._torchlib_registry import onnx_impl


if TYPE_CHECKING:
    from onnxscript.values import Opset

aten = torch.ops.aten

_INT64_MAX = 9223372036854775807
_INT64_MIN = -9223372036854775808


@onnx_impl(aten.gelu.default, trace_only=True, opset_introduced=20)
def aten_gelu_opset20(

@@ -46,3 +52,221 @@ def aten_group_norm(
    return op21.GroupNormalization(
        input, weight, bias, epsilon=eps, num_groups=num_groups
    )


@onnx_impl(
    aten.scaled_dot_product_attention.default, trace_only=True, opset_introduced=23
)
def aten_scaled_dot_product_attention_23(
    query: TFloat,
    key: TFloat,
    value: TFloat,
    attn_mask: Optional[TFloat] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
) -> TFloat:
"""scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> Tensor
|
||||
|
||||
Reference:
|
||||
1. https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||
2. https://onnx.ai/onnx/operators/onnx__Attention.html
|
||||
|
||||
Attempts to convert SDPA to Attention onnx op and fallbacks to an onnx graph equivivalent to the following PyTorch code::
|
||||
scale_factor = 1 / math.sqrt(Q.size(-1)) if scale is None else scale
|
||||
attn_mask = (
|
||||
torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
|
||||
if is_causal
|
||||
else attn_mask
|
||||
)
|
||||
attn_mask = (
|
||||
attn_mask.masked_fill(not attn_mask, -float("inf"))
|
||||
if attn_mask.dtype == torch.bool
|
||||
else attn_mask
|
||||
)
|
||||
attn_weight = torch.softmax(
|
||||
(Q @ K.transpose(-2, -1) * attn_mask, dim=-1
|
||||
)
|
||||
attn_weight = torch.dropout(attn_weight, dropout_p)
|
||||
return attn_weight @ V
|
||||
|
||||
where Q, K, V are the query, key, and value tensors, respectively.
|
||||
L is the target sequence length, S is the source sequence length, and E is the embedding size.
|
||||
"""
    assert (not is_causal) or (is_causal and attn_mask is None), (
        "is_causal and attn_mask cannot be set at the same time"
    )

    # Attention onnx op can only handle non-training scenarios where dropout is disabled.
    if dropout_p == 0:
        if enable_gqa:
            assert (
                query.shape[1] > key.shape[1] == value.shape[1]
                and query.shape[1] % key.shape[1] == 0
            ), (
                "SDPA (GQA or MQA) requires q_num_heads > kv_num_heads & q_num_heads % kv_num_heads == 0"
            )
        else:
            assert query.shape[1] == key.shape[1] == value.shape[1], (
                "SDPA (MHA) requires q_num_heads = kv_num_heads"
            )
        Y, _, _, _ = op23.Attention(
            query,
            key,
            value,
            attn_mask=attn_mask,
            scale=scale,
            q_num_heads=query.shape[3],
            kv_num_heads=key.shape[3],
            is_causal=is_causal,
        )
        return Y

    if scale is None:
        scale = _attention_scale(query, op23)
    scale = op23.CastLike(scale, query)

    if is_causal:
        attn_mask = _causal_attention_mask(query, key, op23)

    if attn_mask is None:
        return _aten_scaled_dot_product_attention_no_mask_onnx(
            query, key, value, scale, dropout_p, op23
        )

    return _aten_scaled_dot_product_attention_float_mask_onnx(
        query, key, value, attn_mask, scale, dropout_p, op23
    )


def _attention_scale(query: TFloat, op: Opset) -> TFloat:
    """Calculate the scale factor for the attention result.

    Args:
        query: Tensor of shape [..., L, E]

    Returns:
        Scalar scale factor := 1 / math.sqrt(query.size(-1))
    """
    q_shape = op.Shape(query)
    q_last_dim = op.Gather(q_shape, op.Constant(value_ints=[-1]))
    embedding_size = op.CastLike(q_last_dim, query)
    one = op.Constant(value_float=1.0)
    cast_one = op.CastLike(one, query)
    scale = op.Div(cast_one, op.Sqrt(embedding_size))
    return scale


def _causal_attention_mask(query: TFloat, key: TFloat, op: Opset) -> TFloat:
    """Create a causal mask for the given query and key tensors.

    Equivalent to::
        mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_mask = torch.zeros(L, S, dtype=torch.float)
        attn_mask = attn_mask.masked_fill(not mask, -float("inf"))

    Args:
        query: Tensor of shape [..., L, E]
        key: Tensor of shape [..., S, E]

    Returns:
        Tensor of shape [L, S]
    """
    q_shape = op.Shape(query)
    k_shape = op.Shape(key)

    target_length = op.Slice(
        q_shape, op.Constant(value_ints=[-2]), op.Constant(value_ints=[-1])
    )
    source_length = op.Slice(
        k_shape, op.Constant(value_ints=[-2]), op.Constant(value_ints=[-1])
    )
    # attn_mask = torch.ones(L, S) := {
    size = op.Concat(target_length, source_length, axis=0)
    attn_mask = op.Expand(op.Constant(value_float=1.0), size)
    # }
    attn_mask = op.Trilu(attn_mask, upper=0)
    # The causal mask has 0s in the lower triangle and -inf in the upper triangle.
    attn_mask = op.Where(
        op.Equal(attn_mask, op.Constant(value_float=0.0)),
        op.Constant(value_float=-float("inf")),
        op.Constant(value_float=0.0),
    )
    attn_mask = op.CastLike(attn_mask, query)
    return attn_mask


def _aten_scaled_dot_product_attention_no_mask_onnx(
    query: TFloat,
    key: TFloat,
    value: TFloat,
    scale: TFloat,
    dropout_p: float,
    op: Opset,
) -> TFloat:
    # Swap the last two axes of key
    key_last_dim = op.Shape(key, start=-1)
    key_second_last_dim = op.Shape(key, start=-2, end=-1)
    key_first_dims = op.Shape(key, end=-2)
    # Contract the dimensions that are not the last two so we can transpose
    # with a static permutation.
    key_squeezed_shape = op.Concat(
        op.Constant(value_ints=[-1]), key_second_last_dim, key_last_dim, axis=0
    )
    key_squeezed = op.Reshape(key, key_squeezed_shape)
    key_squeezed_transposed = op.Transpose(key_squeezed, perm=[0, 2, 1])
    key_transposed_shape = op.Concat(
        key_first_dims, key_last_dim, key_second_last_dim, axis=0
    )
    key_transposed = op.Reshape(key_squeezed_transposed, key_transposed_shape)

    # https://github.com/pytorch/pytorch/blob/12da0c70378b5be9135c6fda62a9863bce4a4818/aten/src/ATen/native/transformers/attention.cpp#L653
    # Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for math
    query_scaled = op.Mul(query, op.Sqrt(scale))
    key_transposed_scaled = op.Mul(
        key_transposed, op.CastLike(op.Sqrt(scale), key_transposed)
    )
    attn_weight = op.Softmax(
        op.MatMul(query_scaled, key_transposed_scaled),
        axis=-1,
    )
    attn_weight, _ = op.Dropout(attn_weight, dropout_p)
    return op.MatMul(attn_weight, value)


def _aten_scaled_dot_product_attention_float_mask_onnx(
    query: TFloat,
    key: TFloat,
    value: TFloat,
    attn_mask: TFloat,
    scale: TFloat,
    dropout_p: float,
    op: Opset,
) -> TFloat:
    # Swap the last two axes of key
    key_last_dim = op.Shape(key, start=-1)
    key_second_last_dim = op.Shape(key, start=-2, end=-1)
    key_first_dims = op.Shape(key, end=-2)
    # Contract the dimensions that are not the last two so we can transpose
    # with a static permutation.
    key_squeezed_shape = op.Concat(
        op.Constant(value_ints=[-1]), key_second_last_dim, key_last_dim, axis=0
    )
    key_squeezed = op.Reshape(key, key_squeezed_shape)
    key_squeezed_transposed = op.Transpose(key_squeezed, perm=[0, 2, 1])
    key_transposed_shape = op.Concat(
        key_first_dims, key_last_dim, key_second_last_dim, axis=0
    )
    key_transposed = op.Reshape(key_squeezed_transposed, key_transposed_shape)

    # https://github.com/pytorch/pytorch/blob/12da0c70378b5be9135c6fda62a9863bce4a4818/aten/src/ATen/native/transformers/attention.cpp#L653
    # Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for math
    query_scaled = op.Mul(query, op.Sqrt(scale))
    key_transposed_scaled = op.Mul(key_transposed, op.Sqrt(scale))
    attn_weight = op.Softmax(
        op.Add(op.MatMul(query_scaled, key_transposed_scaled), attn_mask),
        axis=-1,
    )
    attn_weight, _ = op.Dropout(attn_weight, dropout_p)
    return op.MatMul(attn_weight, value)
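For reference, the fallback decomposition documented in the aten_scaled_dot_product_attention_23 docstring can be checked numerically against PyTorch's own SDPA. A minimal sketch, not part of this commit, specialized to the no-mask, non-causal, zero-dropout case:

import math

import torch

query = torch.rand(2, 4, 8, 16)
key = torch.rand(2, 4, 8, 16)
value = torch.rand(2, 4, 8, 16)

# Decomposition from the docstring, with attn_mask=None and dropout_p=0.
scale_factor = 1 / math.sqrt(query.size(-1))
attn_weight = torch.softmax(query @ key.transpose(-2, -1) * scale_factor, dim=-1)
manual = attn_weight @ value

# Should match the fused implementation within floating-point tolerance.
reference = torch.nn.functional.scaled_dot_product_attention(query, key, value)
torch.testing.assert_close(manual, reference)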