SDPA fix memory efficient attention for large batch dim (#154029)
Fixes #146704
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154029
Approved by: https://github.com/ngimel
commit e313152a33 (parent 3b38989b5f)
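For context: the memory-efficient SDPA path caps the batch dimension it can handle in one kernel call at 65535 (MAX_BATCH_SIZE in the diff below), so larger batches have to be split internally, which is what this change implements. A minimal usage sketch of the scenario being fixed (an illustration, not part of the patch; it assumes a CUDA build with the memory-efficient backend available, and the shapes mirror the new tests below):

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# A batch dimension of 2**16 exceeds the 65535 limit; with this fix the
# forward pass is transparently processed in batch chunks of at most 65535.
query = torch.rand(2**16, 2, 2, 8, device="cuda", dtype=torch.float16)
key = torch.rand(2**16, 2, 2, 8, device="cuda", dtype=torch.float16)
value = torch.rand(2**16, 2, 2, 8, device="cuda", dtype=torch.float16)

with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
    out = F.scaled_dot_product_attention(query, key, value)
print(out.shape)  # torch.Size([65536, 2, 2, 8])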
@@ -46,6 +46,7 @@
 #include <ATen/ops/_triton_multi_head_attention_native.h>
 #include <ATen/ops/_triton_scaled_dot_attention.h>
 #include <ATen/ops/empty.h>
+#include <ATen/ops/empty_strided.h>
 #include <ATen/ops/empty_like.h>
 #include <ATen/ops/linear.h>
 #include <ATen/ops/narrow_native.h>
@@ -963,22 +964,33 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attenti
     std::optional<double> scale) {
   // Used for tracking usage statistics
   C10_LOG_API_USAGE_ONCE("torch.sdpa.mem_efficient_attention");
-  // Query -> Query(Batch x Q_seq_len x Num_heads x Dim_per_head)
-  // Key -> Key(Batch x KV_seq_len x Num_heads x Dim_per_head)
-  // Value -> Value(Batch x KV_seq_len x Num_heads x Dim_per_head)
-  Tensor q_t = query.transpose(1, 2);
-  Tensor k_t = key.transpose(1, 2);
-  Tensor v_t = value.transpose(1, 2);
+  constexpr int64_t MAX_BATCH_SIZE = (1LL << 16) - 1;
+  int64_t batch_size = query.size(0);
+
+  if (batch_size > MAX_BATCH_SIZE) {
+    TORCH_CHECK(!compute_log_sumexp && (dropout_p == 0.0),
+        "Efficient attention cannot produce valid seed, logsumexp and offset outputs when "
+        "the batch size exceeds (", MAX_BATCH_SIZE, ").");
+  }
+  auto process_chunk = [&](const Tensor& q_chunk,
+                           const Tensor& k_chunk,
+                           const Tensor& v_chunk,
+                           const std::optional<Tensor>& bias_chunk)
+      -> std::tuple<Tensor, Tensor, Tensor, Tensor> {
+  Tensor q_t = q_chunk.transpose(1, 2);
+  Tensor k_t = k_chunk.transpose(1, 2);
+  Tensor v_t = v_chunk.transpose(1, 2);
 
   sdp::CustomMaskType custom_mask_type = is_causal
       ? sdp::CustomMaskType::CausalFromTopLeft
       : sdp::CustomMaskType::NoCustomMask;
 
-  auto [attention, log_sumexp, seed, offset, max_seqlen_batch_q, max_seqlen_batch_kv] = at::_efficient_attention_forward(
+  auto [attention, log_sumexp, seed, offset, max_seqlen_batch_q, max_seqlen_batch_kv] =
+      at::_efficient_attention_forward(
       q_t,
       k_t,
       v_t,
-      attn_bias,
+      bias_chunk,
       std::nullopt,
       std::nullopt,
       std::nullopt,
@@ -987,9 +999,63 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attenti
       static_cast<int64_t>(custom_mask_type),
       compute_log_sumexp,
       scale);
 
   attention = attention.transpose(1, 2);
-  return std::make_tuple(std::move(attention), std::move(log_sumexp), std::move(seed), std::move(offset));
+  return std::make_tuple(std::move(attention),
+                         std::move(log_sumexp),
+                         std::move(seed),
+                         std::move(offset));
+  };
+
+  // when bs is larger than allowed maximum, process in chunks
+  if (batch_size > MAX_BATCH_SIZE) {
+    int64_t start = 0;
+    int64_t end = std::min(start + MAX_BATCH_SIZE, batch_size);
+
+    Tensor query_chunk = query.slice(0, start, end);
+    Tensor key_chunk = key.slice(0, start, end);
+    Tensor value_chunk = value.slice(0, start, end);
+    std::optional<Tensor> bias_chunk;
+    if (attn_bias.has_value()) {
+      bias_chunk = attn_bias.value().slice(0, start, end);
+    }
+    auto [attn, log_sumexp, seed, offset] =
+        process_chunk(query_chunk, key_chunk, value_chunk, bias_chunk);
+    int dim = attn.dim();
+    std::vector<int64_t> sizes;
+    sizes.reserve(dim);
+    sizes.push_back(batch_size);
+    for (int i = 1; i < dim; i++) {
+      sizes.push_back(attn.size(i));
+    }
+    Tensor final_attention = at::empty_strided(sizes, attn.strides(), attn.options());
+    final_attention.slice(0, start, end).copy_(attn);
+
+    for (start = end; start < batch_size; start += MAX_BATCH_SIZE) {
+      end = std::min(start + MAX_BATCH_SIZE, batch_size);
+      query_chunk = query.slice(0, start, end);
+      key_chunk = key.slice(0, start, end);
+      value_chunk = value.slice(0, start, end);
+      if (attn_bias.has_value()) {
+        bias_chunk = attn_bias.value().slice(0, start, end);
+      } else {
+        bias_chunk.reset();
+      }
+
+      auto [chunk_attn, chunk_log_sumexp, chunk_seed, chunk_offset] =
+          process_chunk(query_chunk, key_chunk, value_chunk, bias_chunk);
+      final_attention.slice(0, start, end).copy_(chunk_attn);
+    }
+
+    return std::make_tuple(std::move(final_attention),
+                           std::move(log_sumexp),
+                           std::move(seed),
+                           std::move(offset));
+  }
+  // when bs is within the allowed size, no need to chunk it
+  else {
+    return process_chunk(query, key, value, attn_bias);
+  }
 }
 
 int64_t _fused_sdp_choice_cuda(const Tensor& query_, const Tensor& key, const Tensor& value,
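To make the new control flow easier to follow, here is a hedged Python model of the chunking strategy above (illustration only, not the ATen implementation; chunked_attention and run_chunk are made-up names standing in for the per-chunk call to at::_efficient_attention_forward):

import torch

MAX_BATCH = (1 << 16) - 1  # 65535, mirrors MAX_BATCH_SIZE above

def chunked_attention(query, key, value, run_chunk):
    # Split the batch dimension into pieces of at most MAX_BATCH, run each
    # piece, and copy the results into a preallocated full-size output,
    # mirroring the empty_strided + slice().copy_() pattern in the C++ code.
    batch = query.size(0)
    if batch <= MAX_BATCH:
        return run_chunk(query, key, value)
    first = run_chunk(query[:MAX_BATCH], key[:MAX_BATCH], value[:MAX_BATCH])
    out = torch.empty((batch, *first.shape[1:]), dtype=first.dtype, device=first.device)
    out[:MAX_BATCH].copy_(first)
    for start in range(MAX_BATCH, batch, MAX_BATCH):
        end = min(start + MAX_BATCH, batch)
        out[start:end].copy_(run_chunk(query[start:end], key[start:end], value[start:end]))
    return out

Note that only the attention output is stitched back together: the log_sumexp, seed, and offset returned by the C++ function come from the first chunk only, which is why the oversized-batch path is guarded by the TORCH_CHECK that compute_log_sumexp is false and dropout_p is 0.0, and why <ATen/ops/empty_strided.h> is newly included to preallocate the full-size output.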
@@ -1898,6 +1898,26 @@ class TestSDPAFailureModes(NNTestCase):
         self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
             q, k, v, None, 0.0, is_causal=True))
 
+    @onlyCUDA
+    def test_mem_eff_attention_fail_with_batch_size_geq_65536(self):
+        query = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
+        key = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
+        value = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
+        with sdpa_kernel(backends=SDPBackend.EFFICIENT_ATTENTION):
+            out = F.scaled_dot_product_attention(query, key, value)
+        out_cpu = F.scaled_dot_product_attention(query.cpu(), key.cpu(), value.cpu())
+        self.assertEqual(out, out_cpu, atol=1e-3, rtol=1e-4)
+
+    @onlyCUDA
+    def test_mem_eff_attention_fail_with_batch_size_geq_65536_error(self):
+        query = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
+        key = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
+        value = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
+        error_str = (r"Efficient attention cannot produce valid seed, "
+                     r"logsumexp and offset outputs when the batch size exceeds \(65535\)\.")
+        with self.assertRaisesRegex(RuntimeError, error_str):
+            torch._scaled_dot_product_efficient_attention(query, key, value, attn_bias=None, compute_log_sumexp=True)
 
 def _get_block_size_n(device, head_dim, is_dropout, is_causal):
     # This should match the block sizes in the CUDA kernel
     assert head_dim <= 256