From e48ee2cf50d86d87ef7c7d0839267dbed4903ebf Mon Sep 17 00:00:00 2001
From: titaiwangms
Date: Tue, 10 Sep 2024 23:03:59 +0000
Subject: [PATCH] [ONNX] Fix scaled_dot_product_attention with float scale
 (#135594)

Fixes #125158
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135594
Approved by: https://github.com/justinchuby
---
 test/onnx/test_pytorch_onnx_onnxruntime.py | 21 +++++++++++++++++++++
 torch/onnx/symbolic_opset14.py             |  1 -
 2 files changed, 21 insertions(+), 1 deletion(-)

diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 091209a8a4f..c812b8e18b3 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -12534,6 +12534,27 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
 
         self.run_test(M(), (x, y))
 
+    @skipIfUnsupportedMinOpsetVersion(14)
+    def test_scaled_dot_product_attention(self):
+        class M(torch.nn.Module):
+            def forward(self, q, k, v):
+                return torch.nn.functional.scaled_dot_product_attention(
+                    q, k, v, scale=1.0
+                )
+
+        # Parameters
+        batch_size = 2  # Number of samples in the batch
+        num_heads = 4  # Number of attention heads
+        seq_length = 5  # Sequence length
+        head_dim = 8  # Dimensionality of each head
+
+        # Create random query, key, and value tensors
+        q = torch.randn(batch_size, num_heads, seq_length, head_dim)
+        k = torch.randn(batch_size, num_heads, seq_length, head_dim)
+        v = torch.randn(batch_size, num_heads, seq_length, head_dim)
+
+        self.run_test(M(), (q, k, v))
+
     @skipScriptTest()
     @skipIfUnsupportedMinOpsetVersion(11)
     def test_dist_normal(self):
diff --git a/torch/onnx/symbolic_opset14.py b/torch/onnx/symbolic_opset14.py
index 1b10cc28531..ae33ddf58c6 100644
--- a/torch/onnx/symbolic_opset14.py
+++ b/torch/onnx/symbolic_opset14.py
@@ -150,7 +150,6 @@ def scaled_dot_product_attention(
     ), "is_causal and attn_mask cannot be set at the same time"
     assert not enable_gqa, "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
 
-    scale = symbolic_helper._maybe_get_const(scale, "f")
     if symbolic_helper._is_none(scale):
         scale = _attention_scale(g, query)
 
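
For reference, a minimal standalone sketch (not part of the patch) that exercises the fixed path outside the test suite: it exports a module calling `scaled_dot_product_attention` with a float `scale` through `torch.onnx.export` at opset 14, mirroring the new unit test above. The tensor shapes follow the test; the `SDPA` class name and the `sdpa.onnx` output path are arbitrary choices for this example.

```python
import torch


class SDPA(torch.nn.Module):
    def forward(self, q, k, v):
        # A float scale previously failed during opset-14 symbolic conversion;
        # with this patch the export is expected to succeed.
        return torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=1.0)


# Shapes mirror the new test: (batch_size, num_heads, seq_length, head_dim)
q = torch.randn(2, 4, 5, 8)
k = torch.randn(2, 4, 5, 8)
v = torch.randn(2, 4, 5, 8)

# "sdpa.onnx" is an arbitrary output filename for this sketch.
torch.onnx.export(SDPA(), (q, k, v), "sdpa.onnx", opset_version=14)
```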