[CUDA] Fixes for backwards in memefficient attn for large tensors (#154663)
Follow-up to #154029. @ngimel The backward pass had the same problem, so this PR fixes it and adds support for logsumexp computation in the forward pass.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154663
Approved by: https://github.com/ngimel
This commit is contained in:
parent d89d213118
commit d6edefefbf
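For orientation before the diff: the fix splits oversized batches along dim 0, runs the memory-efficient kernel on one slice at a time, and stitches the per-chunk outputs (and, in backward, the per-chunk gradients) into preallocated full-size tensors. The short Python sketch below only illustrates that chunking idea; it is not the ATen code path changed in this PR, and the helper name chunked_sdpa is a hypothetical stand-in (the 65535 limit mirrors the MAX_BATCH_SIZE constant visible in the diff).

import torch
import torch.nn.functional as F

MAX_BATCH_SIZE = 2**16 - 1  # per-launch batch limit, mirroring the kernel constant in the diff

def chunked_sdpa(query, key, value, attn_bias=None):
    # Run scaled dot product attention on batch slices of at most MAX_BATCH_SIZE
    # and concatenate the results; autograd traces the slices, so gradients for
    # query/key/value are likewise assembled chunk by chunk in the backward pass.
    chunks = []
    for start in range(0, query.size(0), MAX_BATCH_SIZE):
        end = min(start + MAX_BATCH_SIZE, query.size(0))
        bias = attn_bias[start:end] if attn_bias is not None else None
        chunks.append(F.scaled_dot_product_attention(
            query[start:end], key[start:end], value[start:end], attn_mask=bias))
    return torch.cat(chunks, dim=0)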
@@ -968,8 +968,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attenti
int64_t batch_size = query.size(0);

if (batch_size > MAX_BATCH_SIZE) {
TORCH_CHECK(!compute_log_sumexp && (dropout_p == 0.0),
"Efficient attention cannot produce valid seed, logsumexp and offset outputs when "
TORCH_CHECK(dropout_p == 0.0,
"Efficient attention cannot produce valid seed and offset outputs when "
"the batch size exceeds (", MAX_BATCH_SIZE, ").");
}
auto process_chunk = [&](const Tensor& q_chunk,
@@ -1030,6 +1030,17 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attenti
}
Tensor final_attention = at::empty_strided(sizes, attn.strides(), attn.options());
final_attention.slice(0, start, end).copy_(attn);
Tensor final_log_sumexp;
if (compute_log_sumexp && log_sumexp.numel() > 0) {
std::vector<int64_t> lse_sizes;
lse_sizes.reserve(log_sumexp.dim());
lse_sizes.push_back(batch_size);
for (int i = 1; i < log_sumexp.dim(); i++) {
lse_sizes.push_back(log_sumexp.size(i));
}
final_log_sumexp = at::empty(std::move(lse_sizes), log_sumexp.options());
final_log_sumexp.slice(0, start, end).copy_(log_sumexp);
}

for (start = end; start < batch_size; start += MAX_BATCH_SIZE) {
end = std::min(start + MAX_BATCH_SIZE, batch_size);
@@ -1045,10 +1056,13 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attenti
auto [chunk_attn, chunk_log_sumexp, chunk_seed, chunk_offset] =
process_chunk(query_chunk, key_chunk, value_chunk, bias_chunk);
final_attention.slice(0, start, end).copy_(chunk_attn);
if (compute_log_sumexp && chunk_log_sumexp.numel() > 0) {
final_log_sumexp.slice(0, start, end).copy_(chunk_log_sumexp);
}
}

return std::make_tuple(std::move(final_attention),
std::move(log_sumexp),
std::move(final_log_sumexp),
std::move(seed),
std::move(offset));
}
@@ -24,6 +24,8 @@
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/zeros_like.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/_flash_attention_backward.h>
#include <ATen/ops/_flash_attention_backward_native.h>
#include <ATen/ops/_efficient_attention_backward.h>
@@ -905,40 +907,56 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_e
if (!grad_out_.defined()) {
return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{});
}
auto grad_out = grad_out_.transpose(1, 2);
auto out_t = out.transpose(1, 2);
auto q_t = query.transpose(1, 2);
auto k_t = key.transpose(1, 2);
auto v_t = value.transpose(1, 2);
constexpr int64_t MAX_BATCH_SIZE = (1LL << 16) - 1;
int64_t batch_size = query.size(0);

if (batch_size > MAX_BATCH_SIZE) {
TORCH_CHECK(dropout_p == 0.0,
"Efficient attention backward cannot handle dropout when "
"the batch size exceeds (", MAX_BATCH_SIZE, ").");
}
auto grad_out_t = grad_out_.transpose(1, 2);
auto query_t = query.transpose(1, 2);
auto key_t = key.transpose(1, 2);
auto value_t = value.transpose(1, 2);
auto out_t = out.transpose(1, 2);

auto process_chunk = [&](const Tensor& grad_out_chunk,
const Tensor& query_chunk,
const Tensor& key_chunk,
const Tensor& value_chunk,
const std::optional<Tensor>& attn_bias_chunk,
const Tensor& out_chunk,
const Tensor& logsumexp_chunk)
-> std::tuple<Tensor, Tensor, Tensor, Tensor> {
// This is needed because SaveVariable automatically converts
// std::optional to undefined tensor
std::optional<Tensor> kernel_bias;
if (attn_bias.defined()) {
kernel_bias = attn_bias;
if (attn_bias_chunk.has_value() && attn_bias_chunk.value().defined()) {
kernel_bias = attn_bias_chunk.value();
}
// Will add with signature changes for dropout and bias
// We are only handling Dense inputs, but this should be passed
// from forward to backward
int64_t max_seqlen_q = q_t.size(1);
int64_t max_seqlen_k = k_t.size(1);
int64_t max_seqlen_q = query_chunk.size(2);
int64_t max_seqlen_k = key_chunk.size(2);

sdp::CustomMaskType custom_mask_type = causal
? sdp::CustomMaskType::CausalFromTopLeft
: sdp::CustomMaskType::NoCustomMask;
auto [grad_q, grad_k, grad_v, grad_bias] =
at::_efficient_attention_backward(
grad_out,
q_t,
k_t,
v_t,
grad_out_chunk,
query_chunk,
key_chunk,
value_chunk,
kernel_bias,
out_t,
out_chunk,
std::nullopt,
std::nullopt,
max_seqlen_q,
max_seqlen_k,
logsumexp,
logsumexp_chunk,
dropout_p,
philox_seed,
philox_offset,
@@ -947,7 +965,90 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_e
scale,
std::nullopt); // num_split_keys
return std::make_tuple(
grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2), grad_bias);
grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2), std::move(grad_bias));
};

// process in chunks if batch size exceeds maximum
if (batch_size > MAX_BATCH_SIZE) {
Tensor final_grad_q, final_grad_k, final_grad_v, final_grad_bias;

auto create_strided_output = [batch_size](const Tensor& tensor) -> Tensor {
if (!tensor.defined()) {
return Tensor{};
}
int dim = tensor.dim();
std::vector<int64_t> sizes;
sizes.reserve(dim);
sizes.push_back(batch_size);
for (int i = 1; i < dim; i++) {
sizes.push_back(tensor.size(i));
}
return at::empty_strided(std::move(sizes), tensor.strides(), tensor.options());
};

if (grad_input_mask[0]) {
final_grad_q = create_strided_output(query);
}

if (grad_input_mask[1]) {
final_grad_k = create_strided_output(key);
}

if (grad_input_mask[2]) {
final_grad_v = create_strided_output(value);
}
if (grad_input_mask[3] && attn_bias.defined()) {
final_grad_bias = at::zeros_like(attn_bias);
}

for (int64_t start = 0; start < batch_size; start += MAX_BATCH_SIZE) {
int64_t end = std::min(start + MAX_BATCH_SIZE, batch_size);

Tensor grad_out_chunk = grad_out_t.slice(0, start, end);
Tensor query_chunk = query_t.slice(0, start, end);
Tensor key_chunk = key_t.slice(0, start, end);
Tensor value_chunk = value_t.slice(0, start, end);
Tensor attn_bias_chunk;
if (attn_bias.defined()) {
attn_bias_chunk = attn_bias.slice(0, start, end);
} else {
attn_bias_chunk.reset();
}
Tensor out_chunk = out_t.slice(0, start, end);
Tensor logsumexp_chunk = logsumexp.numel() > 0 ? logsumexp.slice(0, start, end) : logsumexp;

auto [chunk_grad_q, chunk_grad_k, chunk_grad_v, chunk_grad_bias] =
process_chunk(grad_out_chunk, query_chunk, key_chunk, value_chunk,
attn_bias_chunk, out_chunk, logsumexp_chunk);

if (grad_input_mask[0] && chunk_grad_q.defined()) {
final_grad_q.slice(0, start, end).copy_(chunk_grad_q);
}
if (grad_input_mask[1] && chunk_grad_k.defined()) {
final_grad_k.slice(0, start, end).copy_(chunk_grad_k);
}
if (grad_input_mask[2] && chunk_grad_v.defined()) {
final_grad_v.slice(0, start, end).copy_(chunk_grad_v);
}
if (grad_input_mask[3] && chunk_grad_bias.defined()) {
final_grad_bias.add_(chunk_grad_bias);
}
}

return std::make_tuple(
std::move(final_grad_q),
std::move(final_grad_k),
std::move(final_grad_v),
std::move(final_grad_bias));
}
// when batch size is within allowed size, no chunking needed
else {
std::optional<Tensor> attn_bias_opt;
if (attn_bias.defined()) {
attn_bias_opt = attn_bias;
}
return process_chunk(grad_out_t, query_t, value_t, value_t, attn_bias_opt, out_t, logsumexp);
}
}

} // namespace at::native
@@ -1900,23 +1900,36 @@ class TestSDPAFailureModes(NNTestCase):

@onlyCUDA
def test_mem_eff_attention_fail_with_batch_size_geq_65536(self):
query = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
key = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
value = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
batch_size = 2**16
query = torch.rand([batch_size, 2, 2, 8], device='cuda', dtype=torch.float16, requires_grad=True)
key = torch.rand([batch_size, 2, 2, 8], device='cuda', dtype=torch.float16, requires_grad=True)
value = torch.rand([batch_size, 2, 2, 8], device='cuda', dtype=torch.float16, requires_grad=True)
q_cpu, k_cpu, v_cpu = (query.detach().cpu().requires_grad_(True),
key.detach().cpu().requires_grad_(True),
value.detach().cpu().requires_grad_(True))
with sdpa_kernel(backends=SDPBackend.EFFICIENT_ATTENTION):
out = F.scaled_dot_product_attention(query, key, value)
out_cpu = F.scaled_dot_product_attention(query.cpu(), key.cpu(), value.cpu())
self.assertEqual(out, out_cpu, atol=1e-3, rtol=1e-4)
out_cpu = F.scaled_dot_product_attention(q_cpu, k_cpu, v_cpu)
grad_out = torch.rand_like(out)
out.backward(grad_out)
out_cpu.backward(grad_out.cpu())

self.assertEqual(out, out_cpu, atol=2e-3, rtol=1e-4)
self.assertEqual(query.grad, q_cpu.grad, atol=2e-3, rtol=1e-4)
self.assertEqual(key.grad, k_cpu.grad, atol=2e-3, rtol=1e-4)
self.assertEqual(value.grad, v_cpu.grad, atol=2e-3, rtol=1e-4)

@onlyCUDA
def test_mem_eff_attention_fail_with_batch_size_geq_65536_error(self):
query = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
key = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
value = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
error_str = (r"Efficient attention cannot produce valid seed, "
r"logsumexp and offset outputs when the batch size exceeds \(65535\)\.")
error_str = (r"Efficient attention cannot produce valid seed and offset outputs when "
r"the batch size exceeds \(65535\)\.")
with self.assertRaisesRegex(RuntimeError, error_str):
torch._scaled_dot_product_efficient_attention(query, key, value, attn_bias=None, compute_log_sumexp=True)
torch._scaled_dot_product_efficient_attention(query, key, value,
attn_bias=None, compute_log_sumexp=True,
dropout_p=0.01)

def _get_block_size_n(device, head_dim, is_dropout, is_causal):
# This should match the block sizes in the CUDA kernel