[Reland] [Intel GPU] Make SDPA output have the same stride as Query (#154340)

Fixes [#153903](https://github.com/pytorch/pytorch/issues/153903).

Currently, the output tensor of SDPA on XPU is always allocated with contiguous strides, whereas the CPU/CUDA flash_attention and cudnn_attention backends allocate the output tensor with the same strides as Query.

This PR aligns XPU's behavior with CUDA/CPU so that XPU works with modeling code written against CPU/CUDA's stride assumptions.

The function `alloc_with_matching_layout` is copied from the cuDNN backend: 8c16d0e404/aten/src/ATen/native/cudnn/MHA.cpp (L874).
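
As a quick illustration of the intended behavior (a minimal sketch, not taken from this PR: the shapes and permutation are arbitrary, and `device="xpu"` assumes an XPU-enabled build):

```python
import torch
import torch.nn.functional as F

# Allocate q/k/v as (B, S, H, D) and transpose to the (B, H, S, D) layout SDPA
# expects, so the query is non-contiguous in its logical dimension order.
q = torch.randn(4, 256, 16, 64, device="xpu", dtype=torch.bfloat16).transpose(1, 2)
k = torch.randn(4, 512, 16, 64, device="xpu", dtype=torch.bfloat16).transpose(1, 2)
v = torch.randn(4, 512, 16, 64, device="xpu", dtype=torch.bfloat16).transpose(1, 2)

out = F.scaled_dot_product_attention(q, k, v)

# Before this PR the XPU output was always contiguous; now its strides follow
# the query's memory layout, matching the CUDA/CPU fused kernels.
assert out.transpose(1, 2).is_contiguous()
```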

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154340
Approved by: https://github.com/guangyey, https://github.com/drisspg
Author: fengqing.lu
Date: 2025-06-24 06:09:54 +00:00 (committed by PyTorch MergeBot)
Parent: a7b29c88b1
Commit: 04178d347c
5 changed files with 145 additions and 100 deletions

@@ -92,6 +92,7 @@ void run_cudnn_SDP_bprop(
#include <ATen/cudnn/Types.h>
#include <ATen/cudnn/Utils.h>
#include <ATen/native/cudnn/MHA.h>
#include <ATen/native/transformers/sdp_utils.h>
#include <ATen/cuda/Exceptions.h>
#include <cudnn_frontend.h>
@@ -319,88 +320,6 @@ auto fixSizeOneDimStrideSDPA(
  }
  return strides;
}

void alloc_with_matching_layout(
    const Tensor& q,
    Tensor& output,
    const std::vector<int64_t>& shape) {
  TORCH_INTERNAL_ASSERT(
      shape.size() == q.sizes().size(),
      "cuDNN SDPA alloc_with_matching_layout got requested shape ndim != q ndim");
  if (std::equal(q.sizes().begin(), q.sizes().end(), shape.begin())) {
    output = at::empty_like(q);
    return;
  }

  // get the "fill order," which is just an argsort on the strides
  std::vector<int> fill_order(shape.size());
  std::iota(fill_order.begin(), fill_order.end(), 0);
  const auto q_strides = q.strides();
  std::stable_sort(
      fill_order.begin(), fill_order.end(), [&q_strides](int idx1, int idx2) {
        return q_strides[idx1] < q_strides[idx2];
      });
  std::vector<int64_t> ordered_strides(shape.size());
  int64_t current_stride = 1;
  for (const int dim_idx : fill_order) {
    ordered_strides[dim_idx] = current_stride;
    current_stride *= shape[dim_idx];
  }
  output = at::empty(at::IntArrayRef(shape), q.options())
               .as_strided(
                   at::IntArrayRef(shape), at::IntArrayRef(ordered_strides), 0);
}

void permute_to_matching_layout(const Tensor& output, Tensor& grad_output) {
  const int dims = output.sizes().size();
  std::vector<int64_t> outer_to_inner(dims);
  std::iota(outer_to_inner.begin(), outer_to_inner.end(), 0);
  const auto o_strides = output.strides();
  std::stable_sort(
      outer_to_inner.begin(),
      outer_to_inner.end(),
      [&o_strides](int idx1, int idx2) {
        return o_strides[idx1] > o_strides[idx2];
      });
  std::vector<int64_t> inverse(dims);
  for (int d = 0; d < dims; d++) {
    inverse[d] = std::find(outer_to_inner.begin(), outer_to_inner.end(), d) -
        outer_to_inner.begin();
  }
  grad_output = grad_output.permute(at::IntArrayRef(outer_to_inner))
                    .contiguous()
                    .permute(at::IntArrayRef(inverse));
}

bool same_strides(const Tensor& t1, const Tensor& t2) {
  std::vector<int> t1_strides_no_ones;
  std::vector<int> t2_strides_no_ones;
  const auto t1strides = t1.strides();
  const auto t2strides = t2.strides();
  const int dim = t1strides.size();
  if (dim != (int)t2strides.size()) {
    return false;
  }
  const auto t1sizes = t1.sizes();
  const auto t2sizes = t2.sizes();
  // we are going through strides backward here, but if both are backward it's
  // comparable
  for (int i = 0; i < dim; i++) {
    if (t1sizes[i] > 1) {
      t1_strides_no_ones.push_back(t1strides[i]);
    }
    if (t2sizes[i] > 1) {
      t2_strides_no_ones.push_back(t2strides[i]);
    }
  }
  return std::equal(
      t1_strides_no_ones.begin(),
      t1_strides_no_ones.end(),
      t2_strides_no_ones.begin(),
      t2_strides_no_ones.end());
}
} // namespace
auto build_graph_and_tensors(

@@ -1,5 +1,6 @@
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
#include <ATen/native/transformers/attention.h>
#include <ATen/native/transformers/sdp_utils.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <c10/util/Array.h>
#include <torch/library.h>
@@ -192,9 +193,10 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
  const int64_t seq_len_q = query.size(2);
  const int64_t seq_len_kv = key.size(2);
  auto opts = query.options();
  auto output =
      at::empty({batch_size, num_head_q, seq_len_q, head_dim_v}, opts);
  at::Tensor output;
  std::vector<int64_t> output_shape = {
      batch_size, num_head_q, seq_len_q, head_dim_v};
  alloc_with_matching_layout(query, output, output_shape);
  at::Tensor logsumexp, debug_attn_mask; // not supported
  at::native::onednn::gpu_float_sdpa(

@@ -0,0 +1,88 @@
#pragma once

#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>

namespace at::native {

void alloc_with_matching_layout(
    const Tensor& q,
    Tensor& output,
    const std::vector<int64_t>& shape) {
  TORCH_INTERNAL_ASSERT(
      shape.size() == q.sizes().size(),
      "SDPA alloc_with_matching_layout got requested shape ndim != q ndim");
  if (std::equal(q.sizes().begin(), q.sizes().end(), shape.begin())) {
    output = at::empty_like(q);
    return;
  }

  // get the "fill order," which is just an argsort on the strides
  std::vector<int> fill_order(shape.size());
  std::iota(fill_order.begin(), fill_order.end(), 0);
  const auto q_strides = q.strides();
  std::stable_sort(
      fill_order.begin(), fill_order.end(), [&q_strides](int idx1, int idx2) {
        return q_strides[idx1] < q_strides[idx2];
      });
  std::vector<int64_t> ordered_strides(shape.size());
  int64_t current_stride = 1;
  for (const int dim_idx : fill_order) {
    ordered_strides[dim_idx] = current_stride;
    current_stride *= shape[dim_idx];
  }
  output = at::empty(at::IntArrayRef(shape), q.options())
               .as_strided(
                   at::IntArrayRef(shape), at::IntArrayRef(ordered_strides), 0);
}

void permute_to_matching_layout(const Tensor& output, Tensor& grad_output) {
  const int dims = output.sizes().size();
  std::vector<int64_t> outer_to_inner(dims);
  std::iota(outer_to_inner.begin(), outer_to_inner.end(), 0);
  const auto o_strides = output.strides();
  std::stable_sort(
      outer_to_inner.begin(),
      outer_to_inner.end(),
      [&o_strides](int idx1, int idx2) {
        return o_strides[idx1] > o_strides[idx2];
      });
  std::vector<int64_t> inverse(dims);
  for (int d = 0; d < dims; d++) {
    inverse[d] = std::find(outer_to_inner.begin(), outer_to_inner.end(), d) -
        outer_to_inner.begin();
  }
  grad_output = grad_output.permute(at::IntArrayRef(outer_to_inner))
                    .contiguous()
                    .permute(at::IntArrayRef(inverse));
}

bool same_strides(const Tensor& t1, const Tensor& t2) {
  std::vector<int> t1_strides_no_ones;
  std::vector<int> t2_strides_no_ones;
  const auto t1strides = t1.strides();
  const auto t2strides = t2.strides();
  const int dim = t1strides.size();
  if (dim != (int)t2strides.size()) {
    return false;
  }
  const auto t1sizes = t1.sizes();
  const auto t2sizes = t2.sizes();
  // we are going through strides backward here, but if both are backward it's
  // comparable
  for (int i = 0; i < dim; i++) {
    if (t1sizes[i] > 1) {
      t1_strides_no_ones.push_back(t1strides[i]);
    }
    if (t2sizes[i] > 1) {
      t2_strides_no_ones.push_back(t2strides[i]);
    }
  }
  return std::equal(
      t1_strides_no_ones.begin(),
      t1_strides_no_ones.end(),
      t2_strides_no_ones.begin(),
      t2_strides_no_ones.end());
}
} // namespace at::native
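
For reference, the "fill order" used above is simply an argsort of the query's strides: the output is filled innermost-to-outermost in the same dimension order as the query. A small standalone Python sketch of the same idea (a hypothetical helper for illustration, not part of this diff):

```python
import torch


def strides_matching_layout(like: torch.Tensor, shape: list[int]) -> list[int]:
    # Argsort the reference tensor's strides from innermost to outermost
    # ("fill order"), then fill the requested shape in that order.
    fill_order = sorted(range(len(shape)), key=lambda d: like.stride(d))
    strides = [0] * len(shape)
    current = 1
    for d in fill_order:
        strides[d] = current
        current *= shape[d]
    return strides


# A (B, H, S, D) = (4, 16, 256, 64) query that was allocated as (B, S, H, D)
# and then transposed:
q = torch.randn(4, 256, 16, 64).transpose(1, 2)
print(q.stride())                                     # (262144, 64, 1024, 1)
# Output shape with a smaller value head_dim; strides keep the query's dim order.
print(strides_matching_layout(q, [4, 16, 256, 32]))   # [131072, 32, 512, 1]
```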

@@ -4132,6 +4132,31 @@ class TestSDPAXpuOnly(NNTestCase):
        self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)

    def test_attention_preserves_query_layout(self, device):

        def test_attention(permute_order: list[list[int]]):
            BHSqD = [4, 16, 256, 64]
            BHSkvD = [4, 16, 512, 64]

            shape_q = [BHSqD[idx] for idx in permute_order]
            shape_kv = [BHSkvD[idx] for idx in permute_order]
            reverse = [permute_order.index(idx) for idx in range(4)]

            q = torch.randn(*shape_q, dtype=torch.bfloat16, device=device, requires_grad=False).permute(reverse)
            k = torch.randn(*shape_kv, dtype=torch.bfloat16, device=device, requires_grad=False).permute(reverse)
            v = torch.randn(*shape_kv, dtype=torch.bfloat16, device=device, requires_grad=False).permute(reverse)
            self.assertEqual(q.shape, BHSqD)
            self.assertEqual(k.shape, BHSkvD)
            self.assertEqual(v.shape, BHSkvD)

            out = F.scaled_dot_product_attention(q, k, v)
            self.assertTrue(out.permute(permute_order).is_contiguous())

        permutable = [0, 1, 2]
        permute_orders = itertools.permutations(permutable)

        for permute_order in permute_orders:
            test_attention(list(permute_order) + [3])

    def test_scaled_dot_product_attention_fused_kernels_safe_softmax(self, device):
        dtype = torch.bfloat16
        make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=False)

@@ -5832,6 +5832,26 @@ def meta__scaled_dot_product_flash_attention(
    )


def alloc_with_matching_layout(
    query: Tensor,
    res_shape: tuple[int, ...],
):
    if tuple(query.shape) == res_shape:
        query_t = query.transpose(1, 2)
        res = torch.empty_like(query_t).transpose(1, 2)
    else:
        dim_order = sorted(
            [0, 1, 2, 3], key=lambda idx: query.stride()[idx], reverse=True
        )
        permuted_shape = [res_shape[idx] for idx in dim_order]
        final_permute = [dim_order.index(i) for i in range(len(dim_order))]
        res = torch.empty(
            permuted_shape, dtype=query.dtype, device=query.device
        ).permute(final_permute)
    return res

@register_meta([aten._scaled_dot_product_cudnn_attention])
def meta__scaled_dot_product_cudnn_attention(
    query: Tensor,
@@ -5851,18 +5871,7 @@ def meta__scaled_dot_product_cudnn_attention(
    D_V = value.size(-1)

    res_shape = (B, H, S_Q, D_V)
    if tuple(query.shape) == res_shape:
        query_t = query.transpose(1, 2)
        res = torch.empty_like(query_t).transpose(1, 2)
    else:
        dim_order = sorted(
            [0, 1, 2, 3], key=lambda idx: query.stride()[idx], reverse=True
        )
        permuted_shape = [res_shape[idx] for idx in dim_order]
        final_permute = [dim_order.index(i) for i in range(len(dim_order))]
        res = torch.empty(
            permuted_shape, dtype=query.dtype, device=query.device
        ).permute(final_permute)
    res = alloc_with_matching_layout(query, res_shape)

    logsum_exp = torch.empty(
        (B, H, S_Q),
@@ -5899,14 +5908,16 @@ def meta__scaled_dot_product_fused_attention_overrideable(
    scale: Optional[float] = None,
):
    B = query.size(0)
    H = query.size(1)
    H_Q = query.size(1)
    S_Q = query.size(2)
    S_KV = key.size(2)
    D_V = value.size(-1)

    res = torch.empty((B, H, S_Q, D_V), dtype=query.dtype, device=query.device)
    res_shape = (B, H_Q, S_Q, D_V)
    res = alloc_with_matching_layout(query, res_shape)

    logsum_exp = torch.empty(
        (B, H, S_Q),
        (B, H_Q, S_Q),
        dtype=torch.float,
        device=query.device,
    )