[ONNX] Fix scaled_dot_product_attention with float scale (#135594)
Fixes #125158

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135594
Approved by: https://github.com/justinchuby
This commit is contained in:
parent eb38ee21ba
commit e48ee2cf50
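For reference, a minimal sketch of the export scenario this change targets: passing a Python float as the scale argument of scaled_dot_product_attention and exporting to ONNX. This is illustrative only and not taken from the commit; the module name and output path below are placeholders.

# Illustrative sketch (not part of the commit): exporting a module that passes
# a Python float as `scale`, the case named in the commit title.
import torch


class Attention(torch.nn.Module):
    def forward(self, q, k, v):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=1.0)


q = torch.randn(2, 4, 5, 8)  # (batch, heads, seq_len, head_dim)
k = torch.randn(2, 4, 5, 8)
v = torch.randn(2, 4, 5, 8)

# scaled_dot_product_attention needs opset >= 14; "attention.onnx" is a placeholder path.
torch.onnx.export(Attention(), (q, k, v), "attention.onnx", opset_version=14)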
@@ -12534,6 +12534,27 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
         self.run_test(M(), (x, y))
 
+    @skipIfUnsupportedMinOpsetVersion(14)
+    def test_scaled_dot_product_attention(self):
+        class M(torch.nn.Module):
+            def forward(self, q, k, v):
+                return torch.nn.functional.scaled_dot_product_attention(
+                    q, k, v, scale=1.0
+                )
+
+        # Parameters
+        batch_size = 2  # Number of samples in the batch
+        num_heads = 4  # Number of attention heads
+        seq_length = 5  # Sequence length
+        head_dim = 8  # Dimensionality of each head
+
+        # Create random query, key, and value tensors
+        q = torch.randn(batch_size, num_heads, seq_length, head_dim)
+        k = torch.randn(batch_size, num_heads, seq_length, head_dim)
+        v = torch.randn(batch_size, num_heads, seq_length, head_dim)
+
+        self.run_test(M(), (q, k, v))
+
     @skipScriptTest()
     @skipIfUnsupportedMinOpsetVersion(11)
     def test_dist_normal(self):
@@ -150,7 +150,6 @@ def scaled_dot_product_attention(
     ), "is_causal and attn_mask cannot be set at the same time"
     assert not enable_gqa, "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
 
-    scale = symbolic_helper._maybe_get_const(scale, "f")
     if symbolic_helper._is_none(scale):
         scale = _attention_scale(g, query)
 
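As background on the branch that remains above: when scale is None, the symbolic falls back to _attention_scale(g, query), which corresponds to the operator's documented default of 1 / sqrt(head_dim). A quick eager-mode check of that equivalence, illustrative only and not part of this commit:

# Illustrative check (not from the commit): the default scale used when
# scale=None equals 1/sqrt(head_dim), i.e. 1/sqrt(q.size(-1)).
import math
import torch

q = torch.randn(2, 4, 5, 8)
k = torch.randn(2, 4, 5, 8)
v = torch.randn(2, 4, 5, 8)

default = torch.nn.functional.scaled_dot_product_attention(q, k, v)
explicit = torch.nn.functional.scaled_dot_product_attention(
    q, k, v, scale=1.0 / math.sqrt(q.size(-1))
)
assert torch.allclose(default, explicit, atol=1e-6)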