diff --git a/aten/src/ATen/native/cudnn/MHA.cpp b/aten/src/ATen/native/cudnn/MHA.cpp
index 5d146edb90b..48119a6a3b4 100644
--- a/aten/src/ATen/native/cudnn/MHA.cpp
+++ b/aten/src/ATen/native/cudnn/MHA.cpp
@@ -92,6 +92,7 @@ void run_cudnn_SDP_bprop(
 #include
 #include
 #include
+#include <ATen/native/transformers/sdp_utils.h>
 #include
 #include
 
@@ -319,88 +320,6 @@ auto fixSizeOneDimStrideSDPA(
   }
   return strides;
 }
-
-void alloc_with_matching_layout(
-    const Tensor& q,
-    Tensor& output,
-    const std::vector<int64_t>& shape) {
-  TORCH_INTERNAL_ASSERT(
-      shape.size() == q.sizes().size(),
-      "cuDNN SDPA alloc_with_matching_layout got requested shape ndim != q ndim");
-
-  if (std::equal(q.sizes().begin(), q.sizes().end(), shape.begin())) {
-    output = at::empty_like(q);
-    return;
-  }
-
-  // get the "fill order," which is just an argsort on the strides
-  std::vector<int> fill_order(shape.size());
-  std::iota(fill_order.begin(), fill_order.end(), 0);
-  const auto q_strides = q.strides();
-  std::stable_sort(
-      fill_order.begin(), fill_order.end(), [&q_strides](int idx1, int idx2) {
-        return q_strides[idx1] < q_strides[idx2];
-      });
-  std::vector<int64_t> ordered_strides(shape.size());
-  int64_t current_stride = 1;
-  for (const int dim_idx : fill_order) {
-    ordered_strides[dim_idx] = current_stride;
-    current_stride *= shape[dim_idx];
-  }
-  output = at::empty(at::IntArrayRef(shape), q.options())
-               .as_strided(
-                   at::IntArrayRef(shape), at::IntArrayRef(ordered_strides), 0);
-}
-
-void permute_to_matching_layout(const Tensor& output, Tensor& grad_output) {
-  const int dims = output.sizes().size();
-  std::vector<int64_t> outer_to_inner(dims);
-  std::iota(outer_to_inner.begin(), outer_to_inner.end(), 0);
-  const auto o_strides = output.strides();
-  std::stable_sort(
-      outer_to_inner.begin(),
-      outer_to_inner.end(),
-      [&o_strides](int idx1, int idx2) {
-        return o_strides[idx1] > o_strides[idx2];
-      });
-  std::vector<int64_t> inverse(dims);
-  for (int d = 0; d < dims; d++) {
-    inverse[d] = std::find(outer_to_inner.begin(), outer_to_inner.end(), d) -
-        outer_to_inner.begin();
-  }
-  grad_output = grad_output.permute(at::IntArrayRef(outer_to_inner))
-                    .contiguous()
-                    .permute(at::IntArrayRef(inverse));
-}
-
-bool same_strides(const Tensor& t1, const Tensor& t2) {
-  std::vector<int64_t> t1_strides_no_ones;
-  std::vector<int64_t> t2_strides_no_ones;
-  const auto t1strides = t1.strides();
-  const auto t2strides = t2.strides();
-  const int dim = t1strides.size();
-  if (dim != (int)t2strides.size()) {
-    return false;
-  }
-  const auto t1sizes = t1.sizes();
-  const auto t2sizes = t2.sizes();
-
-  // we are going through strides backward here, but if both are backward it's
-  // comparable
-  for (int i = 0; i < dim; i++) {
-    if (t1sizes[i] > 1) {
-      t1_strides_no_ones.push_back(t1strides[i]);
-    }
-    if (t2sizes[i] > 1) {
-      t2_strides_no_ones.push_back(t2strides[i]);
-    }
-  }
-  return std::equal(
-      t1_strides_no_ones.begin(),
-      t1_strides_no_ones.end(),
-      t2_strides_no_ones.begin(),
-      t2_strides_no_ones.end());
-}
 } // namespace
 
 auto build_graph_and_tensors(
diff --git a/aten/src/ATen/native/mkldnn/xpu/Attention.cpp b/aten/src/ATen/native/mkldnn/xpu/Attention.cpp
index 6d51e4baf4f..0e5c7fcf3a0 100644
--- a/aten/src/ATen/native/mkldnn/xpu/Attention.cpp
+++ b/aten/src/ATen/native/mkldnn/xpu/Attention.cpp
@@ -1,5 +1,6 @@
 #include
 #include
+#include <ATen/native/transformers/sdp_utils.h>
 #include
 #include
 #include
@@ -192,9 +193,10 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
   const int64_t seq_len_q = query.size(2);
   const int64_t seq_len_kv = key.size(2);
 
-  auto opts = query.options();
-  auto output =
-      at::empty({batch_size, num_head_q, seq_len_q, head_dim_v}, opts);
+  at::Tensor output;
+  std::vector<int64_t> output_shape = {
+      batch_size, num_head_q, seq_len_q, head_dim_v};
+  alloc_with_matching_layout(query, output, output_shape);
   at::Tensor logsumexp, debug_attn_mask; // not supported
 
   at::native::onednn::gpu_float_sdpa(
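Note on the helper the XPU path now calls: alloc_with_matching_layout picks output strides by argsorting the query's strides, so the output tensor lives in the same memory order as q even when the requested shape differs. A minimal Python sketch of that fill-order trick (the function name mirrors the C++ helper; the tensor sizes are illustrative, not from the patch):

    import torch

    def alloc_with_matching_layout(q: torch.Tensor, shape: list[int]) -> torch.Tensor:
        # Fill order = argsort of q's strides, smallest (innermost) dim first.
        fill_order = sorted(range(len(shape)), key=lambda d: q.stride(d))
        strides = [0] * len(shape)
        current = 1
        for d in fill_order:
            strides[d] = current
            current *= shape[d]
        # Allocate the storage, then view it with the computed strides.
        return torch.empty(shape, dtype=q.dtype, device=q.device).as_strided(shape, strides)

    q = torch.randn(4, 256, 16, 64).transpose(1, 2)  # (B, H, S, D) view of BSHD storage
    out = alloc_with_matching_layout(q, [4, 16, 256, 32])
    assert out.transpose(1, 2).is_contiguous()  # same memory order as q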
diff --git a/aten/src/ATen/native/transformers/sdp_utils.h b/aten/src/ATen/native/transformers/sdp_utils.h
new file mode 100644
index 00000000000..b19ed00ecc8
--- /dev/null
+++ b/aten/src/ATen/native/transformers/sdp_utils.h
@@ -0,0 +1,88 @@
+#pragma once
+#include <ATen/ATen.h>
+#include <numeric>
+
+namespace at::native {
+
+inline void alloc_with_matching_layout(
+    const Tensor& q,
+    Tensor& output,
+    const std::vector<int64_t>& shape) {
+  TORCH_INTERNAL_ASSERT(
+      shape.size() == q.sizes().size(),
+      "SDPA alloc_with_matching_layout got requested shape ndim != q ndim");
+
+  if (std::equal(q.sizes().begin(), q.sizes().end(), shape.begin())) {
+    output = at::empty_like(q);
+    return;
+  }
+
+  // get the "fill order," which is just an argsort on the strides
+  std::vector<int> fill_order(shape.size());
+  std::iota(fill_order.begin(), fill_order.end(), 0);
+  const auto q_strides = q.strides();
+  std::stable_sort(
+      fill_order.begin(), fill_order.end(), [&q_strides](int idx1, int idx2) {
+        return q_strides[idx1] < q_strides[idx2];
+      });
+  std::vector<int64_t> ordered_strides(shape.size());
+  int64_t current_stride = 1;
+  for (const int dim_idx : fill_order) {
+    ordered_strides[dim_idx] = current_stride;
+    current_stride *= shape[dim_idx];
+  }
+  output = at::empty(at::IntArrayRef(shape), q.options())
+               .as_strided(
+                   at::IntArrayRef(shape), at::IntArrayRef(ordered_strides), 0);
+}
+
+inline void permute_to_matching_layout(const Tensor& output, Tensor& grad_output) {
+  const int dims = output.sizes().size();
+  std::vector<int64_t> outer_to_inner(dims);
+  std::iota(outer_to_inner.begin(), outer_to_inner.end(), 0);
+  const auto o_strides = output.strides();
+  std::stable_sort(
+      outer_to_inner.begin(),
+      outer_to_inner.end(),
+      [&o_strides](int idx1, int idx2) {
+        return o_strides[idx1] > o_strides[idx2];
+      });
+  std::vector<int64_t> inverse(dims);
+  for (int d = 0; d < dims; d++) {
+    inverse[d] = std::find(outer_to_inner.begin(), outer_to_inner.end(), d) -
+        outer_to_inner.begin();
+  }
+  grad_output = grad_output.permute(at::IntArrayRef(outer_to_inner))
+                    .contiguous()
+                    .permute(at::IntArrayRef(inverse));
+}
+
+inline bool same_strides(const Tensor& t1, const Tensor& t2) {
+  std::vector<int64_t> t1_strides_no_ones;
+  std::vector<int64_t> t2_strides_no_ones;
+  const auto t1strides = t1.strides();
+  const auto t2strides = t2.strides();
+  const int dim = t1strides.size();
+  if (dim != (int)t2strides.size()) {
+    return false;
+  }
+  const auto t1sizes = t1.sizes();
+  const auto t2sizes = t2.sizes();
+
+  // we are going through strides backward here, but if both are backward it's
+  // comparable
+  for (int i = 0; i < dim; i++) {
+    if (t1sizes[i] > 1) {
+      t1_strides_no_ones.push_back(t1strides[i]);
+    }
+    if (t2sizes[i] > 1) {
+      t2_strides_no_ones.push_back(t2strides[i]);
+    }
+  }
+  return std::equal(
+      t1_strides_no_ones.begin(),
+      t1_strides_no_ones.end(),
+      t2_strides_no_ones.begin(),
+      t2_strides_no_ones.end());
+}
+} // namespace at::native
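The backward-side helper in this header, permute_to_matching_layout, rewrites grad_output so its strides match output: it permutes the dims into output's outermost-to-innermost order, makes that contiguous, and permutes back. A rough Python equivalent (Python's sorted is stable even with reverse=True, matching the C++ stable_sort; the sizes below are illustrative):

    import torch

    def permute_to_matching_layout(output: torch.Tensor, grad_output: torch.Tensor) -> torch.Tensor:
        # Dims of `output` from outermost (largest stride) to innermost.
        outer_to_inner = sorted(range(output.dim()), key=lambda d: output.stride(d), reverse=True)
        inverse = [outer_to_inner.index(d) for d in range(output.dim())]
        # Repack grad_output in that memory order, then view it back.
        return grad_output.permute(outer_to_inner).contiguous().permute(inverse)

    out = torch.randn(4, 256, 16, 64).transpose(1, 2)  # BSHD storage
    g = torch.randn(4, 16, 256, 64)                    # BHSD storage
    g = permute_to_matching_layout(out, g)
    assert g.stride() == out.stride()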
diff --git a/test/test_transformers.py b/test/test_transformers.py
index 58a1da52007..788a918f5b2 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -4132,6 +4132,31 @@ class TestSDPAXpuOnly(NNTestCase):
             self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)
 
+    def test_attention_preserves_query_layout(self, device):
+
+        def test_attention(permute_order: list[int]):
+            BHSqD = [4, 16, 256, 64]
+            BHSkvD = [4, 16, 512, 64]
+
+            shape_q = [BHSqD[idx] for idx in permute_order]
+            shape_kv = [BHSkvD[idx] for idx in permute_order]
+            reverse = [permute_order.index(idx) for idx in range(4)]
+            q = torch.randn(*shape_q, dtype=torch.bfloat16, device=device, requires_grad=False).permute(reverse)
+            k = torch.randn(*shape_kv, dtype=torch.bfloat16, device=device, requires_grad=False).permute(reverse)
+            v = torch.randn(*shape_kv, dtype=torch.bfloat16, device=device, requires_grad=False).permute(reverse)
+            self.assertEqual(q.shape, BHSqD)
+            self.assertEqual(k.shape, BHSkvD)
+            self.assertEqual(v.shape, BHSkvD)
+
+            out = F.scaled_dot_product_attention(q, k, v)
+            self.assertTrue(out.permute(permute_order).is_contiguous())
+
+        permutable = [0, 1, 2]
+        permute_orders = itertools.permutations(permutable)
+
+        for permute_order in permute_orders:
+            test_attention(list(permute_order) + [3])
+
     def test_scaled_dot_product_attention_fused_kernels_safe_softmax(self, device):
         dtype = torch.bfloat16
         make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=False)
 
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 9a027c66b12..aaab720456a 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -5832,6 +5832,26 @@ def meta__scaled_dot_product_flash_attention(
     )
 
 
+def alloc_with_matching_layout(
+    query: Tensor,
+    res_shape: tuple[int, ...],
+):
+    if tuple(query.shape) == res_shape:
+        query_t = query.transpose(1, 2)
+        res = torch.empty_like(query_t).transpose(1, 2)
+    else:
+        dim_order = sorted(
+            [0, 1, 2, 3], key=lambda idx: query.stride()[idx], reverse=True
+        )
+        permuted_shape = [res_shape[idx] for idx in dim_order]
+        final_permute = [dim_order.index(i) for i in range(len(dim_order))]
+        res = torch.empty(
+            permuted_shape, dtype=query.dtype, device=query.device
+        ).permute(final_permute)
+
+    return res
+
+
 @register_meta([aten._scaled_dot_product_cudnn_attention])
 def meta__scaled_dot_product_cudnn_attention(
     query: Tensor,
@@ -5851,18 +5871,7 @@ def meta__scaled_dot_product_cudnn_attention(
     D_V = value.size(-1)
 
     res_shape = (B, H, S_Q, D_V)
-    if tuple(query.shape) == res_shape:
-        query_t = query.transpose(1, 2)
-        res = torch.empty_like(query_t).transpose(1, 2)
-    else:
-        dim_order = sorted(
-            [0, 1, 2, 3], key=lambda idx: query.stride()[idx], reverse=True
-        )
-        permuted_shape = [res_shape[idx] for idx in dim_order]
-        final_permute = [dim_order.index(i) for i in range(len(dim_order))]
-        res = torch.empty(
-            permuted_shape, dtype=query.dtype, device=query.device
-        ).permute(final_permute)
+    res = alloc_with_matching_layout(query, res_shape)
 
     logsum_exp = torch.empty(
         (B, H, S_Q),
@@ -5899,14 +5908,16 @@ def meta__scaled_dot_product_fused_attention_overrideable(
     scale: Optional[float] = None,
 ):
     B = query.size(0)
-    H = query.size(1)
+    H_Q = query.size(1)
     S_Q = query.size(2)
     S_KV = key.size(2)
     D_V = value.size(-1)
 
-    res = torch.empty((B, H, S_Q, D_V), dtype=query.dtype, device=query.device)
+    res_shape = (B, H_Q, S_Q, D_V)
+    res = alloc_with_matching_layout(query, res_shape)
+
     logsum_exp = torch.empty(
-        (B, H, S_Q),
+        (B, H_Q, S_Q),
         dtype=torch.float,
         device=query.device,
     )
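With the meta registrations above, the layout propagation is observable on the meta device as well. A quick check, assuming aten._scaled_dot_product_cudnn_attention is callable directly with its (query, key, value, attn_bias, compute_log_sumexp, ...) schema and that these shapes hit the empty_like fast path:

    import torch

    # q/k/v are (B, H, S, D) views of BSHD storage, on the meta device.
    q = torch.empty(4, 256, 16, 64, device="meta", dtype=torch.bfloat16).transpose(1, 2)
    k = torch.empty(4, 256, 16, 64, device="meta", dtype=torch.bfloat16).transpose(1, 2)
    v = torch.empty(4, 256, 16, 64, device="meta", dtype=torch.bfloat16).transpose(1, 2)

    out = torch.ops.aten._scaled_dot_product_cudnn_attention(
        q, k, v, None, False  # attn_bias=None, compute_log_sumexp=False
    )[0]
    assert out.stride() == q.stride()  # output layout follows the query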