diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp b/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp
index d8d99897fac..0699d0a5604 100644
--- a/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp
+++ b/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp
@@ -49,16 +49,39 @@ struct SDPALogicalParams {
         "Only FP16/BF16/FP32 datatypes are currently supported");
     const dims scalar_shape = {1};
     std::vector<logical_tensor> inputLogicalTensors;
+
+    at::Tensor reshaped_query = query_;
+    at::Tensor reshaped_key = key_;
+    at::Tensor reshaped_value = value_;
+    at::Tensor reshaped_output = output_;
+    at::Tensor reshaped_attn_mask = attn_mask_.value_or(at::Tensor());
+    if (at::native::onednn::is_broadcast(reshaped_query)) {
+      at::native::onednn::undo_broadcast(reshaped_query);
+    }
+    if (at::native::onednn::is_broadcast(reshaped_key)) {
+      at::native::onednn::undo_broadcast(reshaped_key);
+    }
+    if (at::native::onednn::is_broadcast(reshaped_value)) {
+      at::native::onednn::undo_broadcast(reshaped_value);
+    }
+    if (at::native::onednn::is_broadcast(reshaped_output)) {
+      at::native::onednn::undo_broadcast(reshaped_output);
+    }
+    if (attn_mask_.has_value() &&
+        at::native::onednn::is_broadcast(reshaped_attn_mask)) {
+      at::native::onednn::undo_broadcast(reshaped_attn_mask);
+    }
+
     query = {
         static_cast<size_t>(TensorID::query),
         dtype,
-        query_.sizes().vec(),
-        query_.strides().vec()};
+        reshaped_query.sizes().vec(),
+        reshaped_query.strides().vec()};
     key = {
         static_cast<size_t>(TensorID::key),
         dtype,
-        key_.sizes().vec(),
-        key_.strides().vec()};
+        reshaped_key.sizes().vec(),
+        reshaped_key.strides().vec()};
     scale = {
         static_cast<size_t>(TensorID::scale),
         dtype,
@@ -77,19 +100,19 @@ struct SDPALogicalParams {
       attn_mask = {
           static_cast<size_t>(TensorID::attn_mask),
           dtype,
-          attn_mask_->sizes().vec(),
-          attn_mask_->strides().vec()};
+          reshaped_attn_mask.sizes().vec(),
+          reshaped_attn_mask.strides().vec()};
     }
     value = {
         static_cast<size_t>(TensorID::value),
         dtype,
-        value_.sizes().vec(),
-        value_.strides().vec()};
+        reshaped_value.sizes().vec(),
+        reshaped_value.strides().vec()};
     output = {
         static_cast<size_t>(TensorID::output),
         dtype,
-        output_.sizes().vec(),
-        output_.strides().vec()};
+        reshaped_output.sizes().vec(),
+        reshaped_output.strides().vec()};
   }
   std::vector<logical_tensor> get_input() const {
     std::vector<logical_tensor> input = {query, key, scale};
diff --git a/test/test_transformers.py b/test/test_transformers.py
index 051e5d1d5b8..d40c221d175 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -4006,6 +4006,28 @@ class TestSDPAXpuOnly(NNTestCase):
         with self.assertRaisesRegex(RuntimeError, "No available kernel."):
            _ = F.scaled_dot_product_attention(q, k, v)
 
+    def test_fused_attention_broadcasted_input(self, device):
+        dtype = torch.bfloat16
+        make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=False)
+        batch, num_heads, seqlen, head_dim = 32, 16, 128, 32
+        q_shape = SdpaShape(batch, num_heads, seqlen, head_dim)
+        k_shape = SdpaShape(batch, num_heads, seqlen, head_dim)
+        v_shape = SdpaShape(batch, num_heads, seqlen, head_dim)
+        query, key, value = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape)
+
+        attn_mask_shape = (1, seqlen)
+        attn_mask = make_tensor(attn_mask_shape)
+        attn_mask = attn_mask.expand(1, 1, seqlen, seqlen)
+
+        # compare the fused kernel against the math reference for broadcasted inputs
+        actual = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
+
+        math_ref = torch.ops.aten._scaled_dot_product_attention_math(
+            query.float(), key.float(), value.float(), attn_mask=attn_mask, dropout_p=0.0, is_causal=False)[0]
+
+        self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)
+
     @parametrize("type", ["dense"])
     @parametrize("is_contiguous", [True, False])
     def test_scaled_dot_product_attention_fused_kernels_packed(self, device, type: str, is_contiguous: bool):
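
Note: `torch.Tensor.expand` returns a view whose broadcast dimensions have stride 0, which is the layout the `is_broadcast`/`undo_broadcast` calls above are meant to normalize before sizes and strides are handed to the oneDNN logical tensors. Below is a minimal sketch of the layout exercised by `test_fused_attention_broadcasted_input`; it is an illustration only, and the actual `is_broadcast`/`undo_broadcast` helpers from the oneDNN XPU utilities are not reproduced here.

```python
import torch

# Expanding a (1, seqlen) mask to (1, 1, seqlen, seqlen), as the test does, yields a
# view with stride 0 in every broadcast dimension, so its strides no longer describe
# the underlying buffer element-for-element.
seqlen = 128
attn_mask = torch.rand(1, seqlen)
expanded = attn_mask.expand(1, 1, seqlen, seqlen)

print(expanded.shape)     # torch.Size([1, 1, 128, 128])
print(expanded.stride())  # (0, 0, 0, 1) -- zero strides mark the broadcast dims

# A zero-stride dimension holds identical data along its extent, so it can equivalently
# be described with size 1 (a rough stand-in for what "undoing" the broadcast means;
# the real helper's behavior may differ).
collapsed_sizes = [1 if st == 0 else sz for sz, st in zip(expanded.shape, expanded.stride())]
print(collapsed_sizes)    # [1, 1, 1, 128]
```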