[Reland] [Intel GPU] Make SDPA output have the same stride as Query. (#154340)
Fixes [#153903](https://github.com/pytorch/pytorch/issues/153903).
Currently, the output tensor of SDPA on XPU is always allocated with contiguous strides, while the CPU/CUDA flash_attention and cudnn_attention backends allocate the output tensor with the same strides as Query.
This PR aligns XPU's behavior with CUDA/CPU so that XPU works with modeling code written against the CPU/CUDA stride contract.
The function `alloc_with_matching_layout` is copied from the cuDNN backend at 8c16d0e404/aten/src/ATen/native/cudnn/MHA.cpp (L874).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154340
Approved by: https://github.com/guangyey, https://github.com/drisspg
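
For readers who want to check the contract locally, here is a minimal sketch (not part of the patch) that mirrors the new unit test further down; `device = "xpu"` assumes an XPU build, and CUDA behaves the same way with the flash/cuDNN backends:

```python
import itertools

import torch
import torch.nn.functional as F

# Sketch of the stride contract this PR enforces: SDPA output adopts the
# memory layout (fill order) of Query instead of always being contiguous.
device = "xpu"  # assumes an XPU build; CUDA flash/cuDNN behave the same
B, H, S, D = 4, 16, 256, 64

for order in itertools.permutations([0, 1, 2]):
    order = list(order) + [3]  # head_dim stays the innermost dim
    shape = [[B, H, S, D][i] for i in order]   # physical allocation order
    inverse = [order.index(i) for i in range(4)]
    q, k, v = (
        torch.randn(shape, device=device, dtype=torch.bfloat16).permute(inverse)
        for _ in range(3)
    )
    out = F.scaled_dot_product_attention(q, k, v)
    # Permuting the output by Query's memory order must yield a
    # contiguous tensor, i.e. the output reused Query's fill order.
    assert out.permute(order).is_contiguous()
```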
This commit is contained in:
parent a7b29c88b1
commit 04178d347c
aten/src/ATen/native/cudnn/MHA.cpp

```diff
@@ -92,6 +92,7 @@ void run_cudnn_SDP_bprop(
 #include <ATen/cudnn/Types.h>
 #include <ATen/cudnn/Utils.h>
 #include <ATen/native/cudnn/MHA.h>
+#include <ATen/native/transformers/sdp_utils.h>
 
 #include <ATen/cuda/Exceptions.h>
 #include <cudnn_frontend.h>
```
```diff
@@ -319,88 +320,6 @@ auto fixSizeOneDimStrideSDPA(
   }
   return strides;
 }
-
-void alloc_with_matching_layout(
-    const Tensor& q,
-    Tensor& output,
-    const std::vector<int64_t>& shape) {
-  TORCH_INTERNAL_ASSERT(
-      shape.size() == q.sizes().size(),
-      "cuDNN SDPA alloc_with_matching_layout got requested shape ndim != q ndim");
-
-  if (std::equal(q.sizes().begin(), q.sizes().end(), shape.begin())) {
-    output = at::empty_like(q);
-    return;
-  }
-
-  // get the "fill order," which is just an argsort on the strides
-  std::vector<int> fill_order(shape.size());
-  std::iota(fill_order.begin(), fill_order.end(), 0);
-  const auto q_strides = q.strides();
-  std::stable_sort(
-      fill_order.begin(), fill_order.end(), [&q_strides](int idx1, int idx2) {
-        return q_strides[idx1] < q_strides[idx2];
-      });
-  std::vector<int64_t> ordered_strides(shape.size());
-  int64_t current_stride = 1;
-  for (const int dim_idx : fill_order) {
-    ordered_strides[dim_idx] = current_stride;
-    current_stride *= shape[dim_idx];
-  }
-  output = at::empty(at::IntArrayRef(shape), q.options())
-               .as_strided(
-                   at::IntArrayRef(shape), at::IntArrayRef(ordered_strides), 0);
-}
-
-void permute_to_matching_layout(const Tensor& output, Tensor& grad_output) {
-  const int dims = output.sizes().size();
-  std::vector<int64_t> outer_to_inner(dims);
-  std::iota(outer_to_inner.begin(), outer_to_inner.end(), 0);
-  const auto o_strides = output.strides();
-  std::stable_sort(
-      outer_to_inner.begin(),
-      outer_to_inner.end(),
-      [&o_strides](int idx1, int idx2) {
-        return o_strides[idx1] > o_strides[idx2];
-      });
-  std::vector<int64_t> inverse(dims);
-  for (int d = 0; d < dims; d++) {
-    inverse[d] = std::find(outer_to_inner.begin(), outer_to_inner.end(), d) -
-        outer_to_inner.begin();
-  }
-  grad_output = grad_output.permute(at::IntArrayRef(outer_to_inner))
-                    .contiguous()
-                    .permute(at::IntArrayRef(inverse));
-}
-
-bool same_strides(const Tensor& t1, const Tensor& t2) {
-  std::vector<int> t1_strides_no_ones;
-  std::vector<int> t2_strides_no_ones;
-  const auto t1strides = t1.strides();
-  const auto t2strides = t2.strides();
-  const int dim = t1strides.size();
-  if (dim != (int)t2strides.size()) {
-    return false;
-  }
-  const auto t1sizes = t1.sizes();
-  const auto t2sizes = t2.sizes();
-
-  // we are going through strides backward here, but if both are backward it's
-  // comparable
-  for (int i = 0; i < dim; i++) {
-    if (t1sizes[i] > 1) {
-      t1_strides_no_ones.push_back(t1strides[i]);
-    }
-    if (t2sizes[i] > 1) {
-      t2_strides_no_ones.push_back(t2strides[i]);
-    }
-  }
-  return std::equal(
-      t1_strides_no_ones.begin(),
-      t1_strides_no_ones.end(),
-      t2_strides_no_ones.begin(),
-      t2_strides_no_ones.end());
-}
 } // namespace
 
 auto build_graph_and_tensors(
```
aten/src/ATen/native/mkldnn/xpu/Attention.cpp

```diff
@@ -1,5 +1,6 @@
 #include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
 #include <ATen/native/transformers/attention.h>
+#include <ATen/native/transformers/sdp_utils.h>
 #include <ATen/native/transformers/sdp_utils_cpp.h>
 #include <c10/util/Array.h>
 #include <torch/library.h>
```
```diff
@@ -192,9 +193,10 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
   const int64_t seq_len_q = query.size(2);
   const int64_t seq_len_kv = key.size(2);
 
-  auto opts = query.options();
-  auto output =
-      at::empty({batch_size, num_head_q, seq_len_q, head_dim_v}, opts);
+  at::Tensor output;
+  std::vector<int64_t> output_shape = {
+      batch_size, num_head_q, seq_len_q, head_dim_v};
+  alloc_with_matching_layout(query, output, output_shape);
   at::Tensor logsumexp, debug_attn_mask; // not supported
 
   at::native::onednn::gpu_float_sdpa(
```
aten/src/ATen/native/transformers/sdp_utils.h (new file, 88 lines)
```diff
@@ -0,0 +1,88 @@
+#pragma once
+#include <ATen/ATen.h>
+#include <ATen/core/Tensor.h>
+
+namespace at::native {
+
+void alloc_with_matching_layout(
+    const Tensor& q,
+    Tensor& output,
+    const std::vector<int64_t>& shape) {
+  TORCH_INTERNAL_ASSERT(
+      shape.size() == q.sizes().size(),
+      "SDPA alloc_with_matching_layout got requested shape ndim != q ndim");
+
+  if (std::equal(q.sizes().begin(), q.sizes().end(), shape.begin())) {
+    output = at::empty_like(q);
+    return;
+  }
+
+  // get the "fill order," which is just an argsort on the strides
+  std::vector<int> fill_order(shape.size());
+  std::iota(fill_order.begin(), fill_order.end(), 0);
+  const auto q_strides = q.strides();
+  std::stable_sort(
+      fill_order.begin(), fill_order.end(), [&q_strides](int idx1, int idx2) {
+        return q_strides[idx1] < q_strides[idx2];
+      });
+  std::vector<int64_t> ordered_strides(shape.size());
+  int64_t current_stride = 1;
+  for (const int dim_idx : fill_order) {
+    ordered_strides[dim_idx] = current_stride;
+    current_stride *= shape[dim_idx];
+  }
+  output = at::empty(at::IntArrayRef(shape), q.options())
+               .as_strided(
+                   at::IntArrayRef(shape), at::IntArrayRef(ordered_strides), 0);
+}
+
+void permute_to_matching_layout(const Tensor& output, Tensor& grad_output) {
+  const int dims = output.sizes().size();
+  std::vector<int64_t> outer_to_inner(dims);
+  std::iota(outer_to_inner.begin(), outer_to_inner.end(), 0);
+  const auto o_strides = output.strides();
+  std::stable_sort(
+      outer_to_inner.begin(),
+      outer_to_inner.end(),
+      [&o_strides](int idx1, int idx2) {
+        return o_strides[idx1] > o_strides[idx2];
+      });
+  std::vector<int64_t> inverse(dims);
+  for (int d = 0; d < dims; d++) {
+    inverse[d] = std::find(outer_to_inner.begin(), outer_to_inner.end(), d) -
+        outer_to_inner.begin();
+  }
+  grad_output = grad_output.permute(at::IntArrayRef(outer_to_inner))
+                    .contiguous()
+                    .permute(at::IntArrayRef(inverse));
+}
+
+bool same_strides(const Tensor& t1, const Tensor& t2) {
+  std::vector<int> t1_strides_no_ones;
+  std::vector<int> t2_strides_no_ones;
+  const auto t1strides = t1.strides();
+  const auto t2strides = t2.strides();
+  const int dim = t1strides.size();
+  if (dim != (int)t2strides.size()) {
+    return false;
+  }
+  const auto t1sizes = t1.sizes();
+  const auto t2sizes = t2.sizes();
+
+  // we are going through strides backward here, but if both are backward it's
+  // comparable
+  for (int i = 0; i < dim; i++) {
+    if (t1sizes[i] > 1) {
+      t1_strides_no_ones.push_back(t1strides[i]);
+    }
+    if (t2sizes[i] > 1) {
+      t2_strides_no_ones.push_back(t2strides[i]);
+    }
+  }
+  return std::equal(
+      t1_strides_no_ones.begin(),
+      t1_strides_no_ones.end(),
+      t2_strides_no_ones.begin(),
+      t2_strides_no_ones.end());
+}
+} // namespace at::native
```
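The core of the new header is the "fill order" computation: argsort the query's strides ascending, then assign output strides innermost-out. Below is a small Python model of that logic and of `same_strides`; this is a sketch, and the helper names here are mine, not part of the patch:

```python
import torch

def matching_layout_strides(q_strides, shape):
    # "Fill order" = argsort of query strides, ascending: the smallest-stride
    # (innermost) dim is filled first, exactly like the C++ loop above.
    fill_order = sorted(range(len(shape)), key=lambda d: q_strides[d])
    strides = [0] * len(shape)
    current = 1
    for d in fill_order:
        strides[d] = current
        current *= shape[d]
    return strides

def same_strides(t1, t2):
    # Python model of the header's same_strides: compare strides while
    # ignoring size-1 dims, whose strides carry no layout information.
    if t1.dim() != t2.dim():
        return False
    s1 = [st for sz, st in zip(t1.shape, t1.stride()) if sz > 1]
    s2 = [st for sz, st in zip(t2.shape, t2.stride()) if sz > 1]
    return s1 == s2

# Worked example: a (B, H, S, D) query that is BSHD in memory, i.e. allocated
# contiguous as (B, S, H, D) and viewed via transpose(1, 2).
q = torch.empty(4, 256, 16, 64).transpose(1, 2)    # shape (4, 16, 256, 64)
print(q.stride())                                  # (262144, 64, 1024, 1)
print(matching_layout_strides(q.stride(), list(q.shape)))
# -> [262144, 64, 1024, 1]: the output adopts Query's BSHD fill order.
```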
test/test_transformers.py

```diff
@@ -4132,6 +4132,31 @@ class TestSDPAXpuOnly(NNTestCase):
 
         self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)
 
+    def test_attention_preserves_query_layout(self, device):
+
+        def test_attention(permute_order: list[int]):
+            BHSqD = [4, 16, 256, 64]
+            BHSkvD = [4, 16, 512, 64]
+
+            shape_q = [BHSqD[idx] for idx in permute_order]
+            shape_kv = [BHSkvD[idx] for idx in permute_order]
+            reverse = [permute_order.index(idx) for idx in range(4)]
+            q = torch.randn(*shape_q, dtype=torch.bfloat16, device=device, requires_grad=False).permute(reverse)
+            k = torch.randn(*shape_kv, dtype=torch.bfloat16, device=device, requires_grad=False).permute(reverse)
+            v = torch.randn(*shape_kv, dtype=torch.bfloat16, device=device, requires_grad=False).permute(reverse)
+            self.assertEqual(q.shape, BHSqD)
+            self.assertEqual(k.shape, BHSkvD)
+            self.assertEqual(v.shape, BHSkvD)
+
+            out = F.scaled_dot_product_attention(q, k, v)
+            self.assertTrue(out.permute(permute_order).is_contiguous())
+
+        permutable = [0, 1, 2]
+        permute_orders = itertools.permutations(permutable)
+
+        for permute_order in permute_orders:
+            test_attention(list(permute_order) + [3])
+
     def test_scaled_dot_product_attention_fused_kernels_safe_softmax(self, device):
         dtype = torch.bfloat16
         make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=False)
```
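The test's `reverse` is the inverse permutation of `permute_order`: allocating a tensor in `permute_order` memory layout and permuting it by `reverse` recovers the logical BHSD shape. A standalone illustration with one concrete order (values are mine):

```python
import torch

permute_order = [2, 0, 1, 3]  # memory order: S, B, H, D
reverse = [permute_order.index(i) for i in range(4)]
print(reverse)  # [1, 2, 0, 3]

BHSqD = [4, 16, 256, 64]
shape_q = [BHSqD[i] for i in permute_order]  # [256, 4, 16, 64]
q = torch.randn(shape_q).permute(reverse)
print(list(q.shape))  # [4, 16, 256, 64] -- logical BHSD restored
```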
torch/_meta_registrations.py

```diff
@@ -5832,6 +5832,26 @@ def meta__scaled_dot_product_flash_attention(
     )
 
 
+def alloc_with_matching_layout(
+    query: Tensor,
+    res_shape: tuple[int, ...],
+):
+    if tuple(query.shape) == res_shape:
+        query_t = query.transpose(1, 2)
+        res = torch.empty_like(query_t).transpose(1, 2)
+    else:
+        dim_order = sorted(
+            [0, 1, 2, 3], key=lambda idx: query.stride()[idx], reverse=True
+        )
+        permuted_shape = [res_shape[idx] for idx in dim_order]
+        final_permute = [dim_order.index(i) for i in range(len(dim_order))]
+        res = torch.empty(
+            permuted_shape, dtype=query.dtype, device=query.device
+        ).permute(final_permute)
+
+    return res
+
+
 @register_meta([aten._scaled_dot_product_cudnn_attention])
 def meta__scaled_dot_product_cudnn_attention(
     query: Tensor,
```
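Both branches of the meta helper can be checked eagerly. A quick sketch on CPU (shapes and values are mine, not from the patch), showing that the equal-shape fast path preserves the query's layout and that `dim_order` reproduces its fill order when head_dim_v differs:

```python
import torch

# Case 1: res_shape == query.shape. empty_like on the transposed view
# preserves that view's layout, so transposing back yields a tensor with
# the query's shape and the query's memory order.
q = torch.empty(4, 256, 16, 64).transpose(1, 2)   # BHSD view of BSHD storage
res = torch.empty_like(q.transpose(1, 2)).transpose(1, 2)
print(res.stride())                               # (262144, 64, 1024, 1), as q

# Case 2: shapes differ (e.g. D_V != D_Q). dim_order sorts dims by query
# stride, outermost first; allocating in that order and permuting back gives
# an output whose fill order matches the query's.
res_shape = (4, 16, 256, 32)                      # hypothetical head_dim_v = 32
dim_order = sorted([0, 1, 2, 3], key=lambda i: q.stride()[i], reverse=True)
print(dim_order)                                  # [0, 2, 1, 3] -> B, S, H, D
permuted_shape = [res_shape[i] for i in dim_order]
final_permute = [dim_order.index(i) for i in range(4)]
res = torch.empty(permuted_shape).permute(final_permute)
print(res.shape, res.stride())  # (4, 16, 256, 32), strides in BSHD fill order
```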
```diff
@@ -5851,18 +5871,7 @@
     D_V = value.size(-1)
 
     res_shape = (B, H, S_Q, D_V)
-    if tuple(query.shape) == res_shape:
-        query_t = query.transpose(1, 2)
-        res = torch.empty_like(query_t).transpose(1, 2)
-    else:
-        dim_order = sorted(
-            [0, 1, 2, 3], key=lambda idx: query.stride()[idx], reverse=True
-        )
-        permuted_shape = [res_shape[idx] for idx in dim_order]
-        final_permute = [dim_order.index(i) for i in range(len(dim_order))]
-        res = torch.empty(
-            permuted_shape, dtype=query.dtype, device=query.device
-        ).permute(final_permute)
+    res = alloc_with_matching_layout(query, res_shape)
 
     logsum_exp = torch.empty(
         (B, H, S_Q),
```
```diff
@@ -5899,14 +5908,16 @@ def meta__scaled_dot_product_fused_attention_overrideable(
     scale: Optional[float] = None,
 ):
     B = query.size(0)
-    H = query.size(1)
+    H_Q = query.size(1)
     S_Q = query.size(2)
     S_KV = key.size(2)
     D_V = value.size(-1)
 
-    res = torch.empty((B, H, S_Q, D_V), dtype=query.dtype, device=query.device)
+    res_shape = (B, H_Q, S_Q, D_V)
+    res = alloc_with_matching_layout(query, res_shape)
 
     logsum_exp = torch.empty(
-        (B, H, S_Q),
+        (B, H_Q, S_Q),
         dtype=torch.float,
         device=query.device,
     )
```