[CUDA] Fixes for backwards in memefficient attn for large tensors (#154663)

Follow-up to #154029.

@ngimel The backward pass had the same problem, so this PR fixes it and also adds support for logsumexp computation in the chunked forward pass.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154663
Approved by: https://github.com/ngimel
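For context, a minimal sketch (not part of the commit) of the scenario this change targets, assuming a CUDA device with enough free memory; the shapes, dtype, and backend selection mirror the updated test below and are illustrative only:

# Hedged sketch: mem-efficient SDPA forward and backward with a batch size above the 65535 limit.
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

batch_size = 2**16  # exceeds MAX_BATCH_SIZE (65535), so the op processes the batch in chunks
query = torch.rand([batch_size, 2, 2, 8], device='cuda', dtype=torch.float16, requires_grad=True)
key = torch.rand([batch_size, 2, 2, 8], device='cuda', dtype=torch.float16, requires_grad=True)
value = torch.rand([batch_size, 2, 2, 8], device='cuda', dtype=torch.float16, requires_grad=True)

with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
    out = F.scaled_dot_product_attention(query, key, value)

# With this PR, backward also works for oversized batches (dropout_p must stay 0.0 in that case).
out.backward(torch.rand_like(out))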
Authored by Isalia20 on 2025-05-30 19:30:02 +00:00, committed by PyTorch MergeBot
parent d89d213118
commit d6edefefbf
3 changed files with 155 additions and 27 deletions


@@ -968,8 +968,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attenti
   int64_t batch_size = query.size(0);
   if (batch_size > MAX_BATCH_SIZE) {
-    TORCH_CHECK(!compute_log_sumexp && (dropout_p == 0.0),
-        "Efficient attention cannot produce valid seed, logsumexp and offset outputs when "
+    TORCH_CHECK(dropout_p == 0.0,
+        "Efficient attention cannot produce valid seed and offset outputs when "
         "the batch size exceeds (", MAX_BATCH_SIZE, ").");
   }
   auto process_chunk = [&](const Tensor& q_chunk,
@@ -1030,6 +1030,17 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attenti
     }
     Tensor final_attention = at::empty_strided(sizes, attn.strides(), attn.options());
     final_attention.slice(0, start, end).copy_(attn);
+    Tensor final_log_sumexp;
+    if (compute_log_sumexp && log_sumexp.numel() > 0) {
+      std::vector<int64_t> lse_sizes;
+      lse_sizes.reserve(log_sumexp.dim());
+      lse_sizes.push_back(batch_size);
+      for (int i = 1; i < log_sumexp.dim(); i++) {
+        lse_sizes.push_back(log_sumexp.size(i));
+      }
+      final_log_sumexp = at::empty(std::move(lse_sizes), log_sumexp.options());
+      final_log_sumexp.slice(0, start, end).copy_(log_sumexp);
+    }
     for (start = end; start < batch_size; start += MAX_BATCH_SIZE) {
       end = std::min(start + MAX_BATCH_SIZE, batch_size);
@@ -1045,10 +1056,13 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attenti
       auto [chunk_attn, chunk_log_sumexp, chunk_seed, chunk_offset] =
           process_chunk(query_chunk, key_chunk, value_chunk, bias_chunk);
       final_attention.slice(0, start, end).copy_(chunk_attn);
+      if (compute_log_sumexp && chunk_log_sumexp.numel() > 0) {
+        final_log_sumexp.slice(0, start, end).copy_(chunk_log_sumexp);
+      }
     }
     return std::make_tuple(std::move(final_attention),
-                           std::move(log_sumexp),
+                           std::move(final_log_sumexp),
                            std::move(seed),
                            std::move(offset));
   }


@@ -24,6 +24,8 @@
 #include <ATen/Functions.h>
 #include <ATen/NativeFunctions.h>
 #else
+#include <ATen/ops/zeros_like.h>
+#include <ATen/ops/empty_strided.h>
 #include <ATen/ops/_flash_attention_backward.h>
 #include <ATen/ops/_flash_attention_backward_native.h>
 #include <ATen/ops/_efficient_attention_backward.h>
@@ -905,40 +907,56 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_e
   if (!grad_out_.defined()) {
     return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{});
   }
-  auto grad_out = grad_out_.transpose(1, 2);
-  auto out_t = out.transpose(1, 2);
-  auto q_t = query.transpose(1, 2);
-  auto k_t = key.transpose(1, 2);
-  auto v_t = value.transpose(1, 2);
+  constexpr int64_t MAX_BATCH_SIZE = (1LL << 16) - 1;
+  int64_t batch_size = query.size(0);
+  if (batch_size > MAX_BATCH_SIZE) {
+    TORCH_CHECK(dropout_p == 0.0,
+        "Efficient attention backward cannot handle dropout when "
+        "the batch size exceeds (", MAX_BATCH_SIZE, ").");
+  }
+  auto grad_out_t = grad_out_.transpose(1, 2);
+  auto query_t = query.transpose(1, 2);
+  auto key_t = key.transpose(1, 2);
+  auto value_t = value.transpose(1, 2);
+  auto out_t = out.transpose(1, 2);
+  auto process_chunk = [&](const Tensor& grad_out_chunk,
+                           const Tensor& query_chunk,
+                           const Tensor& key_chunk,
+                           const Tensor& value_chunk,
+                           const std::optional<Tensor>& attn_bias_chunk,
+                           const Tensor& out_chunk,
+                           const Tensor& logsumexp_chunk)
+      -> std::tuple<Tensor, Tensor, Tensor, Tensor> {
     // This is needed because SaveVariable automatically converts
     // std::optional to undefined tensor
     std::optional<Tensor> kernel_bias;
-    if (attn_bias.defined()) {
-      kernel_bias = attn_bias;
+    if (attn_bias_chunk.has_value() && attn_bias_chunk.value().defined()) {
+      kernel_bias = attn_bias_chunk.value();
     }
     // Will add with signauter changes for dropout and bias
     // We are only handling Dense inputs, but this should be passed
     // from forward to backward
-    int64_t max_seqlen_q = q_t.size(1);
-    int64_t max_seqlen_k = k_t.size(1);
+    int64_t max_seqlen_q = query_chunk.size(2);
+    int64_t max_seqlen_k = key_chunk.size(2);
     sdp::CustomMaskType custom_mask_type = causal
         ? sdp::CustomMaskType::CausalFromTopLeft
         : sdp::CustomMaskType::NoCustomMask;
     auto [grad_q, grad_k, grad_v, grad_bias] =
         at::_efficient_attention_backward(
-            grad_out,
-            q_t,
-            k_t,
-            v_t,
+            grad_out_chunk,
+            query_chunk,
+            key_chunk,
+            value_chunk,
             kernel_bias,
-            out_t,
+            out_chunk,
             std::nullopt,
             std::nullopt,
             max_seqlen_q,
             max_seqlen_k,
-            logsumexp,
+            logsumexp_chunk,
             dropout_p,
             philox_seed,
             philox_offset,
@@ -947,7 +965,90 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_e
             scale,
             std::nullopt);  // num_split_keys
     return std::make_tuple(
-        grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2), grad_bias);
+        grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2), std::move(grad_bias));
+  };
+
+  // process in chunks if batch size exceeds maximum
+  if (batch_size > MAX_BATCH_SIZE) {
+    Tensor final_grad_q, final_grad_k, final_grad_v, final_grad_bias;
+    auto create_strided_output = [batch_size](const Tensor& tensor) -> Tensor {
+      if (!tensor.defined()) {
+        return Tensor{};
+      }
+      int dim = tensor.dim();
+      std::vector<int64_t> sizes;
+      sizes.reserve(dim);
+      sizes.push_back(batch_size);
+      for (int i = 1; i < dim; i++) {
+        sizes.push_back(tensor.size(i));
+      }
+      return at::empty_strided(std::move(sizes), tensor.strides(), tensor.options());
+    };
+    if (grad_input_mask[0]) {
+      final_grad_q = create_strided_output(query);
+    }
+    if (grad_input_mask[1]) {
+      final_grad_k = create_strided_output(key);
+    }
+    if (grad_input_mask[2]) {
+      final_grad_v = create_strided_output(value);
+    }
+    if (grad_input_mask[3] && attn_bias.defined()) {
+      final_grad_bias = at::zeros_like(attn_bias);
+    }
+    for (int64_t start = 0; start < batch_size; start += MAX_BATCH_SIZE) {
+      int64_t end = std::min(start + MAX_BATCH_SIZE, batch_size);
+      Tensor grad_out_chunk = grad_out_t.slice(0, start, end);
+      Tensor query_chunk = query_t.slice(0, start, end);
+      Tensor key_chunk = key_t.slice(0, start, end);
+      Tensor value_chunk = value_t.slice(0, start, end);
+      Tensor attn_bias_chunk;
+      if (attn_bias.defined()) {
+        attn_bias_chunk = attn_bias.slice(0, start, end);
+      } else {
+        attn_bias_chunk.reset();
+      }
+      Tensor out_chunk = out_t.slice(0, start, end);
+      Tensor logsumexp_chunk = logsumexp.numel() > 0 ? logsumexp.slice(0, start, end) : logsumexp;
+      auto [chunk_grad_q, chunk_grad_k, chunk_grad_v, chunk_grad_bias] =
+          process_chunk(grad_out_chunk, query_chunk, key_chunk, value_chunk,
+                        attn_bias_chunk, out_chunk, logsumexp_chunk);
+      if (grad_input_mask[0] && chunk_grad_q.defined()) {
+        final_grad_q.slice(0, start, end).copy_(chunk_grad_q);
+      }
+      if (grad_input_mask[1] && chunk_grad_k.defined()) {
+        final_grad_k.slice(0, start, end).copy_(chunk_grad_k);
+      }
+      if (grad_input_mask[2] && chunk_grad_v.defined()) {
+        final_grad_v.slice(0, start, end).copy_(chunk_grad_v);
+      }
+      if (grad_input_mask[3] && chunk_grad_bias.defined()) {
+        final_grad_bias.add_(chunk_grad_bias);
+      }
+    }
+    return std::make_tuple(
+        std::move(final_grad_q),
+        std::move(final_grad_k),
+        std::move(final_grad_v),
+        std::move(final_grad_bias));
+  }
+  // when batch size is within allowed size, no chunking needed
+  else {
+    std::optional<Tensor> attn_bias_opt;
+    if (attn_bias.defined()) {
+      attn_bias_opt = attn_bias;
+    }
+    return process_chunk(grad_out_t, query_t, key_t, value_t, attn_bias_opt, out_t, logsumexp);
+  }
 }
 } // namespace at::native


@@ -1900,23 +1900,36 @@ class TestSDPAFailureModes(NNTestCase):
     @onlyCUDA
     def test_mem_eff_attention_fail_with_batch_size_geq_65536(self):
-        query = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
-        key = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
-        value = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
+        batch_size = 2**16
+        query = torch.rand([batch_size, 2, 2, 8], device='cuda', dtype=torch.float16, requires_grad=True)
+        key = torch.rand([batch_size, 2, 2, 8], device='cuda', dtype=torch.float16, requires_grad=True)
+        value = torch.rand([batch_size, 2, 2, 8], device='cuda', dtype=torch.float16, requires_grad=True)
+        q_cpu, k_cpu, v_cpu = (query.detach().cpu().requires_grad_(True),
+                               key.detach().cpu().requires_grad_(True),
+                               value.detach().cpu().requires_grad_(True))
         with sdpa_kernel(backends=SDPBackend.EFFICIENT_ATTENTION):
             out = F.scaled_dot_product_attention(query, key, value)
-        out_cpu = F.scaled_dot_product_attention(query.cpu(), key.cpu(), value.cpu())
-        self.assertEqual(out, out_cpu, atol=1e-3, rtol=1e-4)
+        out_cpu = F.scaled_dot_product_attention(q_cpu, k_cpu, v_cpu)
+        grad_out = torch.rand_like(out)
+        out.backward(grad_out)
+        out_cpu.backward(grad_out.cpu())
+        self.assertEqual(out, out_cpu, atol=2e-3, rtol=1e-4)
+        self.assertEqual(query.grad, q_cpu.grad, atol=2e-3, rtol=1e-4)
+        self.assertEqual(key.grad, k_cpu.grad, atol=2e-3, rtol=1e-4)
+        self.assertEqual(value.grad, v_cpu.grad, atol=2e-3, rtol=1e-4)

     @onlyCUDA
     def test_mem_eff_attention_fail_with_batch_size_geq_65536_error(self):
         query = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
         key = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
         value = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
-        error_str = (r"Efficient attention cannot produce valid seed, "
-                     r"logsumexp and offset outputs when the batch size exceeds \(65535\)\.")
+        error_str = (r"Efficient attention cannot produce valid seed and offset outputs when "
+                     r"the batch size exceeds \(65535\)\.")
         with self.assertRaisesRegex(RuntimeError, error_str):
-            torch._scaled_dot_product_efficient_attention(query, key, value, attn_bias=None, compute_log_sumexp=True)
+            torch._scaled_dot_product_efficient_attention(query, key, value,
+                                                          attn_bias=None, compute_log_sumexp=True,
+                                                          dropout_p=0.01)

 def _get_block_size_n(device, head_dim, is_dropout, is_causal):
     # This should match the block sizes in the CUDA kernel