From e313152a3326ee6b5c2d967f122fd1d7150deaa7 Mon Sep 17 00:00:00 2001
From: Isalia20
Date: Wed, 28 May 2025 16:53:53 +0000
Subject: [PATCH] SDPA fix memory efficient attention for large batch dim (#154029)

Fixes #146704

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154029
Approved by: https://github.com/ngimel
---
 .../native/transformers/cuda/attention.cu | 114 ++++++++++++++----
 test/test_transformers.py                 |  20 +++
 2 files changed, 110 insertions(+), 24 deletions(-)

diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu
index b4aeab10a75..affca278ad1 100644
--- a/aten/src/ATen/native/transformers/cuda/attention.cu
+++ b/aten/src/ATen/native/transformers/cuda/attention.cu
@@ -46,6 +46,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -963,33 +964,98 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attenti
     std::optional<double> scale) {
   // Used for tracking usage statistics
   C10_LOG_API_USAGE_ONCE("torch.sdpa.mem_efficient_attention");
-  // Query -> Query(Batch x Q_seq_len x Num_heads x Dim_per_head)
-  // Key -> Key(Batch x KV_seq_len x Num_heads x Dim_per_head)
-  // Value -> Value(Batch x KV_seq_len x Num_heads x Dim_per_head)
-  Tensor q_t = query.transpose(1, 2);
-  Tensor k_t = key.transpose(1, 2);
-  Tensor v_t = value.transpose(1, 2);
+  constexpr int64_t MAX_BATCH_SIZE = (1LL << 16) - 1;
+  int64_t batch_size = query.size(0);
 
-  sdp::CustomMaskType custom_mask_type = is_causal
-      ? sdp::CustomMaskType::CausalFromTopLeft
-      : sdp::CustomMaskType::NoCustomMask;
+  if (batch_size > MAX_BATCH_SIZE) {
+    TORCH_CHECK(!compute_log_sumexp && (dropout_p == 0.0),
+                "Efficient attention cannot produce valid seed, logsumexp and offset outputs when "
+                "the batch size exceeds (", MAX_BATCH_SIZE, ").");
+  }
+  auto process_chunk = [&](const Tensor& q_chunk,
+                           const Tensor& k_chunk,
+                           const Tensor& v_chunk,
+                           const std::optional<Tensor>& bias_chunk)
+      -> std::tuple<Tensor, Tensor, Tensor, Tensor> {
+    Tensor q_t = q_chunk.transpose(1, 2);
+    Tensor k_t = k_chunk.transpose(1, 2);
+    Tensor v_t = v_chunk.transpose(1, 2);
 
-  auto [attention, log_sumexp, seed, offset, max_seqlen_batch_q, max_seqlen_batch_kv] = at::_efficient_attention_forward(
-      q_t,
-      k_t,
-      v_t,
-      attn_bias,
-      std::nullopt,
-      std::nullopt,
-      std::nullopt,
-      std::nullopt,
-      dropout_p,
-      static_cast<int64_t>(custom_mask_type),
-      compute_log_sumexp,
-      scale);
+    sdp::CustomMaskType custom_mask_type = is_causal
+        ? sdp::CustomMaskType::CausalFromTopLeft
+        : sdp::CustomMaskType::NoCustomMask;
 
-  attention = attention.transpose(1, 2);
-  return std::make_tuple(std::move(attention), std::move(log_sumexp), std::move(seed), std::move(offset));
+    auto [attention, log_sumexp, seed, offset, max_seqlen_batch_q, max_seqlen_batch_kv] =
+        at::_efficient_attention_forward(
+            q_t,
+            k_t,
+            v_t,
+            bias_chunk,
+            std::nullopt,
+            std::nullopt,
+            std::nullopt,
+            std::nullopt,
+            dropout_p,
+            static_cast<int64_t>(custom_mask_type),
+            compute_log_sumexp,
+            scale);
+    attention = attention.transpose(1, 2);
+
+    return std::make_tuple(std::move(attention),
+                           std::move(log_sumexp),
+                           std::move(seed),
+                           std::move(offset));
+  };
+
+  // when bs is larger than allowed maximum, process in chunks
+  if (batch_size > MAX_BATCH_SIZE) {
+    int64_t start = 0;
+    int64_t end = std::min(start + MAX_BATCH_SIZE, batch_size);
+
+    Tensor query_chunk = query.slice(0, start, end);
+    Tensor key_chunk = key.slice(0, start, end);
+    Tensor value_chunk = value.slice(0, start, end);
+    std::optional<Tensor> bias_chunk;
+    if (attn_bias.has_value()) {
+      bias_chunk = attn_bias.value().slice(0, start, end);
+    }
+    auto [attn, log_sumexp, seed, offset] =
+        process_chunk(query_chunk, key_chunk, value_chunk, bias_chunk);
+    int dim = attn.dim();
+    std::vector<int64_t> sizes;
+    sizes.reserve(dim);
+    sizes.push_back(batch_size);
+    for (int i = 1; i < dim; i++) {
+      sizes.push_back(attn.size(i));
+    }
+    Tensor final_attention = at::empty_strided(sizes, attn.strides(), attn.options());
+    final_attention.slice(0, start, end).copy_(attn);
+
+    for (start = end; start < batch_size; start += MAX_BATCH_SIZE) {
+      end = std::min(start + MAX_BATCH_SIZE, batch_size);
+      query_chunk = query.slice(0, start, end);
+      key_chunk = key.slice(0, start, end);
+      value_chunk = value.slice(0, start, end);
+      if (attn_bias.has_value()) {
+        bias_chunk = attn_bias.value().slice(0, start, end);
+      } else {
+        bias_chunk.reset();
+      }
+
+      auto [chunk_attn, chunk_log_sumexp, chunk_seed, chunk_offset] =
+          process_chunk(query_chunk, key_chunk, value_chunk, bias_chunk);
+      final_attention.slice(0, start, end).copy_(chunk_attn);
+    }
+
+    return std::make_tuple(std::move(final_attention),
+                           std::move(log_sumexp),
+                           std::move(seed),
+                           std::move(offset));
+  }
+  // when bs is within the allowed size, no need to chunk it
+  else {
+    return process_chunk(query, key, value, attn_bias);
+  }
 }
 
 int64_t _fused_sdp_choice_cuda(const Tensor& query_, const Tensor& key, const Tensor& value,
diff --git a/test/test_transformers.py b/test/test_transformers.py
index b1284612697..34eaca2390d 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -1898,6 +1898,26 @@ class TestSDPAFailureModes(NNTestCase):
         self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
             q, k, v, None, 0.0, is_causal=True))
 
+    @onlyCUDA
+    def test_mem_eff_attention_fail_with_batch_size_geq_65536(self):
+        query = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
+        key = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
+        value = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
+        with sdpa_kernel(backends=SDPBackend.EFFICIENT_ATTENTION):
+            out = F.scaled_dot_product_attention(query, key, value)
+        out_cpu = F.scaled_dot_product_attention(query.cpu(), key.cpu(), value.cpu())
+        self.assertEqual(out, out_cpu, atol=1e-3, rtol=1e-4)
+
+    @onlyCUDA
+    def test_mem_eff_attention_fail_with_batch_size_geq_65536_error(self):
+        query = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
+        key = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
+        value = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
+        error_str = (r"Efficient attention cannot produce valid seed, "
+                     r"logsumexp and offset outputs when the batch size exceeds \(65535\)\.")
+        with self.assertRaisesRegex(RuntimeError, error_str):
+            torch._scaled_dot_product_efficient_attention(query, key, value, attn_bias=None, compute_log_sumexp=True)
+
 def _get_block_size_n(device, head_dim, is_dropout, is_causal):
     # This should match the block sizes in the CUDA kernel
     assert head_dim <= 256
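
Note (not part of the patch): a minimal usage sketch of the behavior the new tests exercise, assuming a CUDA build where the memory-efficient SDPA backend is available. It drives scaled_dot_product_attention with a batch dimension above the 65535 kernel limit, which the chunking path added in attention.cu now splits along dim 0 and copies back into a single output:

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    # A batch of 2**16 exceeds the kernel's 65535 batch limit; the forward is
    # now run in chunks of at most 65535 along dim 0 and stitched together.
    query = torch.rand(2**16, 2, 2, 8, device="cuda", dtype=torch.float16)
    key = torch.rand(2**16, 2, 2, 8, device="cuda", dtype=torch.float16)
    value = torch.rand(2**16, 2, 2, 8, device="cuda", dtype=torch.float16)

    with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
        out = F.scaled_dot_product_attention(query, key, value)

    print(out.shape)  # torch.Size([65536, 2, 2, 8])

Requesting logsumexp or a non-zero dropout with such a batch size still raises the "Efficient attention cannot produce valid seed, logsumexp and offset outputs ..." error checked by the second test.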