[ONNX] Fix scaled_dot_product_attention with float scale (#135594)

Fixes #125158

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135594
Approved by: https://github.com/justinchuby
titaiwangms 2024-09-10 23:03:59 +00:00 committed by PyTorch MergeBot
parent eb38ee21ba
commit e48ee2cf50
2 changed files with 21 additions and 1 deletion
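For context, a minimal sketch of the export path this change fixes, per the report in #125158: calling scaled_dot_product_attention with a Python float scale and exporting through the TorchScript-based ONNX exporter. The module name, shapes, and output path below are illustrative, not taken from the issue.

    import torch

    class Attn(torch.nn.Module):  # illustrative repro module
        def forward(self, q, k, v):
            # A Python float `scale` is what previously broke the export
            return torch.nn.functional.scaled_dot_product_attention(
                q, k, v, scale=1.0
            )

    q = k = v = torch.randn(2, 4, 5, 8)
    # scaled_dot_product_attention requires opset >= 14 in the ONNX exporter
    torch.onnx.export(Attn(), (q, k, v), "attn.onnx", opset_version=14)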

test/onnx/test_pytorch_onnx_onnxruntime.py

@@ -12534,6 +12534,27 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
 
         self.run_test(M(), (x, y))
 
+    @skipIfUnsupportedMinOpsetVersion(14)
+    def test_scaled_dot_product_attention(self):
+        class M(torch.nn.Module):
+            def forward(self, q, k, v):
+                return torch.nn.functional.scaled_dot_product_attention(
+                    q, k, v, scale=1.0
+                )
+
+        # Parameters
+        batch_size = 2  # Number of samples in the batch
+        num_heads = 4  # Number of attention heads
+        seq_length = 5  # Sequence length
+        head_dim = 8  # Dimensionality of each head
+
+        # Create random query, key, and value tensors
+        q = torch.randn(batch_size, num_heads, seq_length, head_dim)
+        k = torch.randn(batch_size, num_heads, seq_length, head_dim)
+        v = torch.randn(batch_size, num_heads, seq_length, head_dim)
+
+        self.run_test(M(), (q, k, v))
+
     @skipScriptTest()
     @skipIfUnsupportedMinOpsetVersion(11)
     def test_dist_normal(self):
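A note on the semantics the new test pins down: with attn_mask unset and no dropout, scaled_dot_product_attention(q, k, v, scale=s) computes softmax(q @ k^T * s) @ v. A quick eager-mode check of that equivalence, using the same shapes as the test above (this sketch is for illustration and is not part of the commit):

    import torch
    import torch.nn.functional as F

    q, k, v = (torch.randn(2, 4, 5, 8) for _ in range(3))
    scale = 1.0  # the explicit float scale the test exercises

    # Reference: plain attention with the explicit scale applied to q @ k^T
    reference = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1) @ v
    torch.testing.assert_close(
        F.scaled_dot_product_attention(q, k, v, scale=scale), reference
    )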

torch/onnx/symbolic_opset14.py

@@ -150,7 +150,6 @@ def scaled_dot_product_attention(
     ), "is_causal and attn_mask cannot be set at the same time"
     assert not enable_gqa, "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
-    scale = symbolic_helper._maybe_get_const(scale, "f")
     if symbolic_helper._is_none(scale):
         scale = _attention_scale(g, query)
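The deleted line is the whole fix: _maybe_get_const unwrapped a constant scale into a plain Python float, which the later g.op graph-building calls cannot accept in place of a graph value; dropping the call leaves scale as a value in the graph. When scale is genuinely absent, _attention_scale supplies the documented default of 1 / sqrt(head_dim). An eager-mode sketch of that default behavior (not the helper's actual graph construction):

    import math
    import torch
    import torch.nn.functional as F

    q, k, v = (torch.randn(2, 4, 5, 8) for _ in range(3))

    # Per the PyTorch docs, the default scale is 1 / sqrt(E),
    # where E is the embedding (head) dimension of the query.
    explicit = F.scaled_dot_product_attention(
        q, k, v, scale=1.0 / math.sqrt(q.size(-1))
    )
    torch.testing.assert_close(F.scaled_dot_product_attention(q, k, v), explicit)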