[ROCm] CK Memory-Efficient Attention (attention bias support) (#147778)
Implements CK as the backend for memory-efficient attention, with a couple of caveats:
- Still enabled via `torch.backends.cuda.preferred_rocm_fa_library("ck")`
- Does NOT support Nested Tensors
Using the mem_eff path allows us to use an attention bias with the CK SDPA backend, as sketched below.
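A minimal usage sketch (not part of the commit; it assumes a ROCm build with CK flash attention compiled in, and the shapes are illustrative):

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    # Route ROCm flash-attention kernels through Composable Kernel (CK).
    torch.backends.cuda.preferred_rocm_fa_library("ck")

    q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
    k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
    v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
    # Dense (non-nested) attention bias: one additive score per (query, key) pair.
    bias = torch.randn(2, 8, 128, 128, device="cuda", dtype=torch.float16)

    # Pinning SDPA to the memory-efficient backend exercises the CK mem_eff
    # path, which is what adds attention-bias support in this change.
    with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
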
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147778
Approved by: https://github.com/houseroad
parent a1cb67b69e
commit 4d10da731b
@@ -88,6 +88,7 @@
#include <ATen/native/transformers/hip/aotriton_adapter.h>
#include <aotriton/flash.h>
#include <aotriton/runtime.h>
#include <ATen/native/transformers/hip/flash_attn/ck/me_ck_api.h>
#endif
#endif

@@ -1243,104 +1244,146 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
#ifdef USE_ROCM
// ROCM Implementation
auto ret = aotriton::v2::flash::check_gpu(stream);
if (hipSuccess != ret) {
TORCH_CHECK(false,
"[AOTriton] Accelerated SDPA only supports MI200/MI300X/7900XTX/9070XT GPUs"
" (gfx90a/gfx942/gfx1100/gfx1201)")
}
// AOTriton may accept aligned on logsumexp tensor in the future for better
// performance, but for now it requires compact logsumexp tensor, even if
// compute_logsumexp is false
constexpr int kAlignLSE = 1;
// Need this in both aot and CK case
const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float();
res = at::empty({B, M, num_heads, Kv}, query.options());
at::Tensor softmax_lse;
logsumexp = at::empty(
if(at::globalContext().getROCmFAPreferredBackend() ==
at::ROCmFABackend::Ck) {
#if defined(USE_CK_FLASH_ATTENTION)
std::optional<Tensor> out(res);
std::optional<Tensor> seqused_k = std::nullopt;
std::optional<Tensor> alibi_slopes = std::nullopt;
auto
[out_,
q,
k,
v,
lse,
seed_t,
offset_t,
p] =
pytorch_flash::mem_eff_forward_ck(
query,
key,
value,
dropout_p,
false, // return dropout_randval
custom_mask_type == 0 ? false : true, // is_causal
softmax_scale,
bias,
out,
std::nullopt, // cu_seqlens_q
std::nullopt, // cu_seqlens_k
seqstart_q,
seqstart_k,
std::nullopt, // gen_
seqused_k); // seqused_k_
logsumexp = lse;
#else
TORCH_CHECK(false, "Attempting to use CK mem_eff_forward backend in a build that has not built CK");
#endif
} else { // use aotriton
auto ret = aotriton::v2::flash::check_gpu(stream);
if (hipSuccess != ret) {
TORCH_CHECK(false,
"[AOTriton] Accelerated SDPA only supports MI200/MI300X/Navi31 GPUs"
" (gfx90a:sramecc+:xnack-/gfx942:sramecc+:xnack-/gfx1100)")
}
// AOTriton may accept aligned on logsumexp tensor in the future for better
// performance, but for now it requires compact logsumexp tensor, even if
// compute_logsumexp is false
constexpr int kAlignLSE = 1;
res = at::empty({B, M, num_heads, Kv}, query.options());
at::Tensor softmax_lse;
logsumexp = at::empty(
{ B, num_heads, compute_logsumexp ? max_seqlen_q : 0},
query.options().dtype(at::ScalarType::Float));
if (compute_logsumexp) {
softmax_lse = logsumexp.view({B * num_heads, max_seqlen_q});
}
at::Tensor q_t = query.transpose(1, 2);
at::Tensor k_t = key.transpose(1, 2);
at::Tensor v_t = value.transpose(1, 2);
at::Tensor output_t = res.transpose(1, 2);
bool is_causal;
if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
is_causal = true;
} else if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
is_causal = false;
} else {
TORCH_CHECK(false, "[_efficient_attention_forward] Unsupported mask type on ROCM, for now");
}
const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float();
at::Tensor atomic_counter;
if (is_causal) {
atomic_counter = at::zeros({1}, query.options().dtype(at::kInt));
}
using aotriton::v2::flash::attn_fwd;
using aotriton::v2::flash::attn_fwd_compact_varlen;
using sdp::aotriton_adapter::mk_aotensor;
using sdp::aotriton_adapter::mk_aoscalartensor;
using sdp::aotriton_adapter::mk_philoxtensor;
using sdp::aotriton_adapter::mk_atomictensor;
aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, aotriton::DType::kFloat16);
aotriton::TensorView<2> empty_t2(0, {0, 0}, {0, 0}, aotriton::DType::kFloat32);
at::Tensor softmax_fa_t = at::empty({ 0, 0, 0, 0 }, query.options());
const bool use_philox_state = in_capture_stream;
auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t);
auto offset1 = use_philox_state ? mk_philoxtensor(philox_state.offset_.ptr) : mk_aoscalartensor(offset_t);
auto offset2 = use_philox_state ? philox_state.offset_intragraph_ : 0;
auto seed_output = mk_philoxtensor(use_philox_state ? seed_t.data_ptr<int64_t>() : nullptr);
auto offset_output = mk_philoxtensor(use_philox_state ? offset_t.data_ptr<int64_t>() : nullptr);
auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr<int32_t>() : nullptr);
hipError_t err; // TODO: Error handling
if (seqstart_q.has_value()) {
// varlen aka nested tensor
err = attn_fwd_compact_varlen(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4,
mk_aotensor<1>(seqstart_q.value(), "cu_seqlens_q"),
mk_aotensor<1>(seqstart_k.value(), "cu_seqlens_k"),
max_seqlen_q,
max_seqlen_k,
softmax_scale,
compute_logsumexp ? mk_aotensor<2>(softmax_lse, "M") : empty_t2,
mk_aotensor(output_t, "Out"),
dropout_p,
seed,
offset1,
offset2,
seed_output,
offset_output,
mk_aotensor(softmax_fa_t, "encoded_softmax"),
is_causal,
persistent_counter,
stream);
} else {
err = attn_fwd(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4,
softmax_scale,
compute_logsumexp ? mk_aotensor<2>(softmax_lse, "M") : empty_t2,
mk_aotensor(output_t, "Out"),
dropout_p,
seed,
offset1,
offset2,
seed_output,
offset_output,
mk_aotensor(softmax_fa_t, "encoded_softmax"),
is_causal,
persistent_counter,
stream);
}
} // CK BACKEND
#else
// CUDA Implementation
cudaDeviceProp* p = at::cuda::getDeviceProperties(query.device().index());

@@ -47,6 +47,7 @@
#include <ATen/native/transformers/hip/aotriton_adapter.h>
#include <aotriton/flash.h>
#include <aotriton/runtime.h>
#include <ATen/native/transformers/hip/flash_attn/ck/me_ck_api.h>
#endif
#endif

@@ -412,94 +413,133 @@ _efficient_attention_backward(
#ifdef USE_ROCM
// ROCM Implementation
TORCH_CHECK(!num_splits_key.has_value(),
if(at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck)
{
#if defined(USE_CK_FLASH_ATTENTION)
const auto my_softmax_scale = sdp::calculate_scale(query, scale).expect_float();
// Store grad_bias in optional
std::optional<at::Tensor> opt_grad_bias = grad_bias;
auto
[dQ,
dK,
dV,
dBias] =
pytorch_flash::mem_eff_backward_ck(
grad_out,
query,
key,
value,
out,
logsumexp,
grad_q,
grad_k,
grad_v,
bias,
bias_requires_grad,
opt_grad_bias,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
float(dropout_p),
my_softmax_scale,
custom_mask_type == 0 ? false : true, // is_causal
false, // deterministic
false, // zero_tensors
philox_seed,
philox_offset);
grad_bias = dBias;
#else
TORCH_CHECK(false, "Attempting to use CK mem_eff_backward backend in a build that has not built CK");
#endif
} else {
TORCH_CHECK(!num_splits_key.has_value(),
"ROCM does not support num_split_keys in _efficient_attention_forward");
TORCH_CHECK(!window_size.has_value(),
TORCH_CHECK(!window_size.has_value(),
"ROCM does not support window_size in _efficient_attention_forward");
auto ret = aotriton::v2::flash::check_gpu(stream);
if (hipSuccess != ret) {
TORCH_CHECK(false,
auto ret = aotriton::v2::flash::check_gpu(stream);
if (hipSuccess != ret) {
TORCH_CHECK(false,
"[AOTriton] Accelerated SDPA only supports MI200/MI300X/7900XTX/9070XT GPUs"
" (gfx90a/gfx942/gfx1100/gfx1201)")
}
const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float();
bool is_causal;
if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
is_causal = true;
} else if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
is_causal = false;
} else {
TORCH_CHECK(false, "[_efficient_attention_backward] Unsupported mask type in AOTriton, for now");
}
at::Tensor q_t = query.permute({0,2,1,3});
at::Tensor k_t = key.permute({0,2,1,3});
at::Tensor v_t = value.permute({0,2,1,3});
at::Tensor out_t = out.permute({0,2,1,3});
at::Tensor dq_t = grad_q.permute({0,2,1,3});
at::Tensor dk_t = grad_k.permute({0,2,1,3});
at::Tensor dv_t = grad_v.permute({0,2,1,3});
at::Tensor dout_t = grad_out.permute({0,2,1,3});
at::Tensor softmax_lse = logsumexp.view({B * nH, max_seqlen_q});
hipError_t err;
using aotriton::v2::flash::attn_bwd;
using aotriton::v2::flash::attn_bwd_fused;
using aotriton::v2::flash::attn_bwd_compact_varlen;
using sdp::aotriton_adapter::mk_aotensor;
using sdp::aotriton_adapter::mk_aoscalartensor;
using sdp::aotriton_adapter::cast_dtype;
aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, cast_dtype(query.dtype()));
if (cu_seqlens_q.has_value()) {
at::Tensor delta = at::empty_like(softmax_lse).contiguous();
// varlen aka Nested tensor
err = attn_bwd_compact_varlen(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
mk_aotensor<1>(cu_seqlens_q.value(), "cu_seqlens_q"),
mk_aotensor<1>(cu_seqlens_k.value(), "cu_seqlens_k"),
max_seqlen_q,
max_seqlen_k,
bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4,
softmax_scale,
mk_aotensor(out_t, "out"),
mk_aotensor(dout_t, "dout"),
mk_aotensor(dq_t, "dq"),
mk_aotensor(dk_t, "dk"),
mk_aotensor(dv_t, "dv"),
bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4,
mk_aotensor<2>(softmax_lse, "L"),
mk_aotensor<2>(delta, "delta"),
float(dropout_p),
mk_aoscalartensor(philox_seed),
mk_aoscalartensor(philox_offset),
0,
is_causal,
stream);
} else {
auto d_head = Kv;
bool use_fused_bwd = d_head <= 192 && d_head * max_seqlen_q < 64 * 512;
if (use_fused_bwd) {
err = attn_bwd_fused(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4,
softmax_scale,
mk_aotensor(out_t, "out"),
mk_aotensor(dout_t, "dout"),
mk_aotensor(dq_t, "dq"),
mk_aotensor(dk_t, "dk"),
mk_aotensor(dv_t, "dv"),
bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4,
mk_aotensor<2>(softmax_lse, "L"),
float(dropout_p),
mk_aoscalartensor(philox_seed),
mk_aoscalartensor(philox_offset),
0,
is_causal,
stream);
}
if (cu_seqlens_q.has_value()) {
at::Tensor delta = at::empty_like(softmax_lse).contiguous();
err = attn_bwd(mk_aotensor(q_t, "q"),
// varlen aka Nested tensor
err = attn_bwd_compact_varlen(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
mk_aotensor<1>(cu_seqlens_q.value(), "cu_seqlens_q"),
mk_aotensor<1>(cu_seqlens_k.value(), "cu_seqlens_k"),
max_seqlen_q,
max_seqlen_k,
bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4,
softmax_scale,
mk_aotensor(out_t, "out"),
mk_aotensor(dout_t, "dout"),
mk_aotensor(dq_t, "dq"),
mk_aotensor(dk_t, "dk"),
mk_aotensor(dv_t, "dv"),
bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4,
mk_aotensor<2>(softmax_lse, "L"),
mk_aotensor<2>(delta, "delta"),
float(dropout_p),
mk_aoscalartensor(philox_seed),
mk_aoscalartensor(philox_offset),
0,
is_causal,
stream);
} else { // cu_seqlens.has_value
auto d_head = Kv;
bool use_fused_bwd = d_head <= 192 && d_head * max_seqlen_q < 64 * 512;
if (use_fused_bwd) {
err = attn_bwd_fused(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4,
softmax_scale,
mk_aotensor(out_t, "out"),
mk_aotensor(dout_t, "dout"),
mk_aotensor(dq_t, "dq"),
mk_aotensor(dk_t, "dk"),
mk_aotensor(dv_t, "dv"),
bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4,
mk_aotensor<2>(softmax_lse, "L"),
float(dropout_p),
mk_aoscalartensor(philox_seed),
mk_aoscalartensor(philox_offset),
0,
is_causal,
stream);
} else {
at::Tensor delta = at::empty_like(softmax_lse).contiguous();
err = attn_bwd(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4,

@@ -518,9 +558,10 @@ _efficient_attention_backward(
0,
is_causal,
stream);
}
}
#else
} //used_fused_bwd
} // cuseqlen.has_value
} // Use CK
#else // USE_CUDA
at::Tensor workspace;
cudaDeviceProp* p = at::cuda::getDeviceProperties(query.device().index());
const int computeCapability = p->major * 10 + p->minor;

@@ -787,4 +787,4 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
}
} // namespace pytorch_flash
#endif
#endif // USE_FLASH_ATTENTION

@@ -0,0 +1,120 @@
#include <ATen/native/transformers/hip/flash_attn/flash_common_hip.hpp>
#include <ATen/native/transformers/hip/flash_attn/ck/me_ck_api.h>
#if defined(USE_CK_FLASH_ATTENTION)
namespace pytorch_flash {
std::tuple<
at::Tensor, // dQ
at::Tensor, // dK
at::Tensor, // dV
at::Tensor> // dBias
mem_eff_backward_ck(
const at::Tensor &dout,
const at::Tensor &q,
const at::Tensor &k,
const at::Tensor &v,
const at::Tensor &out,
const at::Tensor &softmax_lse,
const at::Tensor &dq_,
const at::Tensor &dk_,
const at::Tensor &dv_,
std::optional<at::Tensor> &attn_bias,
bool bias_requires_grad,
std::optional<at::Tensor> &grad_bias,
std::optional<at::Tensor> &cu_seqlens_q,
std::optional<at::Tensor> &cu_seqlens_k,
int max_seqlen_q,
int max_seqlen_k,
float p_dropout,
float scale,
bool is_causal,
bool deterministic,
bool zero_tensors,
at::Tensor philox_seed,
at::Tensor philox_offset)
{
const int non_null_window_left = -1;
const int non_null_window_right = -1;
std::optional<at::Tensor> opt_dQ, opt_dK, opt_dV;
opt_dQ = dq_;
opt_dK = dk_;
opt_dV = dv_;
if(!cu_seqlens_q.has_value()) {
auto
[dQ,
dK,
dV,
softmax_d,
dBias] =
mha_bwd_ck(
dout,
q,
k,
v,
out,
softmax_lse,
opt_dQ,
opt_dK,
opt_dV,
attn_bias,
bias_requires_grad,
grad_bias,
p_dropout,
scale,
is_causal,
non_null_window_left,
non_null_window_right,
deterministic,
philox_seed,
philox_offset);
return std::make_tuple(std::move(dQ), std::move(dK), std::move(dV), std::move(dBias));
} else {
// cu_seqlens only has a value in the nested tensor path which CK does not support
TORCH_CHECK(false, "Nested Tensors not supported with CK backend.");
return std::make_tuple(at::Tensor{}, at::Tensor{}, at::Tensor{}, at::Tensor{});
// TODO: Fix nested tensor(varlen) path
/*
auto
[dQ,
dK,
dV,
softmax_d,
dBias] =
mha_varlen_bwd_ck(
dout,
q,
k,
v,
out,
softmax_lse,
opt_dQ,
opt_dK,
opt_dV,
cu_seqlens_q.value(),
cu_seqlens_k.value(),
attn_bias,
bias_requires_grad,
grad_bias,
max_seqlen_q,
max_seqlen_k,
p_dropout,
scale,
zero_tensors,
is_causal,
non_null_window_left,
non_null_window_right,
deterministic,
philox_seed,
philox_offset);
return std::make_tuple(std::move(dQ), std::move(dK), std::move(dV), std::move(dBias));
*/
}
return std::make_tuple(at::Tensor{}, at::Tensor{}, at::Tensor{}, at::Tensor{});
}
} // namespace pytorch_flash
#endif // USE_CK_FLASH_ATTENTION

@@ -0,0 +1,67 @@
#pragma once
#include <cstddef>
#include <ATen/core/Tensor.h>
#if defined(USE_CK_FLASH_ATTENTION)
namespace pytorch_flash {
std::tuple<
at::Tensor, // output
at::Tensor, // q
at::Tensor, // k
at::Tensor, // v
at::Tensor, // lse
at::Tensor, // seed
at::Tensor, // offset
at::Tensor> // dropout randval
mem_eff_forward_ck(
const at::Tensor& q,
const at::Tensor& k,
const at::Tensor& v,
float p_dropout,
bool return_dropout_randval,
std::optional<bool> is_causal,
std::optional<float> scale,
const std::optional<at::Tensor>& attn_bias_,
std::optional<at::Tensor>& out_,
const std::optional<at::Tensor>& cu_seqlens_q,
const std::optional<at::Tensor>& cu_seqlens_k,
const std::optional<at::Tensor>& seqstart_q,
const std::optional<at::Tensor>& seqstart_k,
std::optional<at::Generator> gen_,
std::optional<at::Tensor>& seqused_k_
);
std::tuple<
at::Tensor, // dQ
at::Tensor, // dK
at::Tensor, // dV
at::Tensor> // dBias
mem_eff_backward_ck(
const at::Tensor &dout,
const at::Tensor &q,
const at::Tensor &k,
const at::Tensor &v,
const at::Tensor &out,
const at::Tensor &softmax_lse,
const at::Tensor &dq_,
const at::Tensor &dk_,
const at::Tensor &dv_,
std::optional<at::Tensor> &attn_bias,
bool bias_requires_grad,
std::optional<at::Tensor> &grad_bias,
std::optional<at::Tensor> &cu_seqlens_q,
std::optional<at::Tensor> &cu_seqlens_k,
int max_seqlen_q,
int max_seqlen_k,
float p_dropout,
float scale,
bool is_causal,
bool deterministic,
bool zero_tensors,
const at::Tensor philox_seed,
const at::Tensor philox_offset);
} // namespace pytorch_flash
#endif // USE_CK_FLASH_ATTENTION

@@ -0,0 +1,96 @@
#include <ATen/native/transformers/hip/flash_attn/flash_common_hip.hpp>
#include <ATen/native/transformers/hip/flash_attn/ck/me_ck_api.h>
#if defined(USE_CK_FLASH_ATTENTION)
namespace pytorch_flash {
std::tuple<
at::Tensor, // output
at::Tensor, // q
at::Tensor, // k
at::Tensor, // v
at::Tensor, // lse
at::Tensor, // seed
at::Tensor, // offset
at::Tensor> // dropout randval
mem_eff_forward_ck(
const at::Tensor& q,
const at::Tensor& k,
const at::Tensor& v,
float p_dropout,
bool return_dropout_randval,
std::optional<bool> is_causal,
std::optional<float> scale,
const std::optional<at::Tensor>& attn_bias_,
std::optional<at::Tensor>& out_,
const std::optional<at::Tensor>& cu_seqlens_q,
const std::optional<at::Tensor>& cu_seqlens_k,
const std::optional<at::Tensor>& seqstart_q,
const std::optional<at::Tensor>& seqstart_k,
std::optional<at::Generator> gen_,
std::optional<at::Tensor>& seqused_k_) {
const int non_null_window_left = -1;
const int non_null_window_right = -1;
TORCH_CHECK(
cu_seqlens_q.has_value() == cu_seqlens_k.has_value(),
"cu_seqlens_q and cu_seqlens_k must be both set or both not set");
if(!seqstart_q.has_value()){
return mha_fwd_ck(
q, // q
k, // k
v, // v
out_, // opt(out_)
p_dropout, // p_dropout
scale.value(), // opt(softmax_scale)
is_causal.value(), // opt(is_causal)
non_null_window_left, // window_size_left
non_null_window_right, // window_size_right
false, // return_softmax/return_debug_mask
gen_, // gen
attn_bias_); // attn_bias
} else {
// seqstart_q is only set in nested tensor path which CK does not support
TORCH_CHECK(false, "Nested Tensors not supported with CK backend.");
return std::make_tuple(at::Tensor{},
at::Tensor{},
at::Tensor{},
at::Tensor{},
at::Tensor{},
at::Tensor{},
at::Tensor{},
at::Tensor{});
// TODO: Fix nested tensor(varlen) path
/*
// max sequence lengths are now at T.size(1) since q,k,v were all transposed
// in _scaled_dot_product_efficient_attention_cuda
const int64_t max_seqlen_q = q.size(1);
const int64_t max_seqlen_k = k.size(1);
return mha_varlen_fwd_ck(
q, // q
k, // k
v, // v
out_, // opt(out)
seqstart_q.value(), // cu_seqlens_q
seqstart_k.value(), // cu_seqlens_k
seqused_k_, // opt(seqused_k)
max_seqlen_q, // max_seqlen_q
max_seqlen_k, // max_seqlen_k
p_dropout, // p_dropout
scale.value(), // softmax_scale
false, // zero_tensors
is_causal.value(), // is_causal
non_null_window_left, // window_size_left
non_null_window_right, // window_size_right
false, // return_softmax/return_debug_mask
gen_, // gen
attn_bias_); // attn_bias
*/
}
}
} // namespace pytorch_flash
#endif // USE_CK_FLASH_ATTENTION

@@ -12,16 +12,17 @@ fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask,
std::string dtype,
int head_size,
bool has_dropout,
bool enable_alibi,
bool deterministic)
bool enable_bias,
bool deterministic,
bool bias_requires_grad)
{
return fmha_bwd_traits{head_size,
head_size,
dtype,
false, // is_group_mode
mask.type,
enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
false, // has_dbias
enable_bias ? bias_enum::elementwise_bias : bias_enum::no_bias,
bias_requires_grad, // has_dbias
has_dropout,
false, // s_randval
deterministic};

@@ -39,7 +40,9 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
const at::Tensor q,
const at::Tensor k,
const at::Tensor v,
std::optional<at::Tensor> &alibi_slopes_,
std::optional<at::Tensor> &attn_bias_,
bool bias_requires_grad,
std::optional<at::Tensor> &grad_bias,
const at::Tensor out,
const at::Tensor softmax_lse,
const at::Tensor dout,

@@ -82,7 +85,6 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
ck_tile::index_t nhead_stride_do = dout.stride(2);
// d: (batch_size, nheads, seqlen_q)
// CK assume d share the same stride with lse
// dq: (batch_size, seqlen_q, nheads, hdim)
ck_tile::index_t batch_stride_dq = dq.stride(0);

@@ -105,86 +107,103 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
ck_tile::index_t stride_dq_acc = dq_acc.stride(2);
ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(3);
float p_undrop = 1.0 - p_dropout;
void *alibi_slopes_ptr = nullptr;
ck_tile::index_t stride_alibi_slopes = 0;
if (alibi_slopes_.has_value()) {
auto alibi_slopes = alibi_slopes_.value();
CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
TORCH_CHECK(alibi_slopes.sizes() == at::IntArrayRef({h}) || alibi_slopes.sizes() == at::IntArrayRef({b, h}));
alibi_slopes_ptr = alibi_slopes.data_ptr();
// alibi_slopes:(batch_size, nheads) or (nhead)
stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
// bias: (batch_size, nheads, seqlen_q, seqlen_k)
void *attn_bias_ptr = nullptr;
ck_tile::index_t nhead_stride_bias = 0;
ck_tile::index_t batch_stride_bias = 0;
ck_tile::index_t stride_attn_bias = 0;
if (attn_bias_.has_value()) {
auto a_b = attn_bias_.value();
CHECK_DEVICE(a_b);
TORCH_CHECK(a_b.stride(-1) == 1, "Attention bias tensor must have contiguous last dimension");
attn_bias_ptr = a_b.data_ptr();
stride_attn_bias = a_b.stride(2);
nhead_stride_bias = a_b.stride(1);
batch_stride_bias = a_b.stride(0);
}
// dbias: (batch_size, nheads, seqlen_q, seqlen_k)
void *dbias_ptr = nullptr;
ck_tile::index_t stride_dbias = 0;
ck_tile::index_t nhead_stride_dbias = 0;
ck_tile::index_t batch_stride_dbias = 0;
if(bias_requires_grad) {
// If bias_requires_grad is true, grad_bias is guaranteed to have a value via line 270
//grad_bias
auto dbias = grad_bias.value();
dbias_ptr = dbias.data_ptr();
stride_dbias = dbias.stride(2);
nhead_stride_dbias = dbias.stride(1);
batch_stride_dbias = dbias.stride(0);
}
float p_undrop = 1.0 - p_dropout;
return fmha_bwd_args{q.data_ptr(),
k.data_ptr(),
v.data_ptr(),
alibi_slopes_ptr, // bias
attn_bias_ptr, // bias
out.data_ptr(),
softmax_lse.data_ptr(),
dout.data_ptr(),
d.data_ptr(),
nullptr, // rand_val
nullptr, // rand_val
dq.data_ptr(),
dk.data_ptr(),
dv.data_ptr(),
nullptr, // dbias
dq_acc.data_ptr(), // dq_acc
nullptr, // seqstart_q
nullptr, // seqstart_k
nullptr, // seqlen_k_ptr
dbias_ptr, // dbias
dq_acc.data_ptr(), // dq_acc
nullptr, // seqstart_q
nullptr, // seqstart_k
nullptr, // seqlen_k_ptr
seqlen_q,
seqlen_k,
b,
seqlen_q, // max_seqlen_q
seqlen_k, // max_seqlen_k
hdim, // hdim_q
hdim, // hdim_v
h, // nhead
h_k, // nhead_k
seqlen_q, // max_seqlen_q
seqlen_k, // max_seqlen_k
hdim, // hdim_q
hdim, // hdim_v
h, // nhead
h_k, // nhead_k
softmax_scale,
stride_q,
stride_k,
stride_v,
stride_alibi_slopes,
stride_attn_bias,
stride_o,
0, // stride_randval
0, // stride_randval
stride_do,
stride_dq_acc,
stride_dq,
stride_dk,
stride_dv,
0, // stride_dbias, FA without bias
stride_dbias, // stride_dbias
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
0, // nhead_stride_bias, FA without bias
nhead_stride_bias, // nhead_stride_bias
nhead_stride_o,
0, // nhead_stride_randval
0, // nhead_stride_randval
nhead_stride_do,
nhead_stride_lse,
nhead_stride_dq_acc,
nhead_stride_dq,
nhead_stride_dk,
nhead_stride_dv,
0, // nhead_stride_dbias, FA without dbias
nhead_stride_dbias, // nhead_stride_dbias
batch_stride_q,
batch_stride_k,
batch_stride_v,
0 , // batch_stride_bias, FA without bias
batch_stride_bias, // batch_stride_bias
batch_stride_o,
0, // batch_stride_randval
0, // batch_stride_randval
batch_stride_do,
batch_stride_lse,
batch_stride_dq_acc,
batch_stride_dq,
batch_stride_dk,
batch_stride_dv,
0 , // batch_stride_dbias, FA without dbias
batch_stride_dbias, // batch_stride_dbias
split_stride_dq_acc,
mask.left,
mask.right,

@@ -193,8 +212,7 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
p_undrop,
drop_seed_offset};
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og
const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size

@@ -204,7 +222,9 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x
std::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
std::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
std::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
std::optional<at::Tensor> &attn_bias_, // num_heads or batch_size x num_heads
bool bias_requires_grad,
std::optional<at::Tensor> &grad_bias,
const float p_dropout, // probability to drop
const float softmax_scale,
const bool is_causal,

@@ -242,6 +262,9 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x
TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
TORCH_CHECK((bias_requires_grad && grad_bias.has_value()) || (!bias_requires_grad),
"If bias_requires_grad is set, grad_bias must have a value");
const auto sizes = q.sizes();
const int batch_size = sizes[0];

@@ -354,7 +377,13 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x
ck_tile::stream_config stream_config{stream};
dq.zero_(); // ck use atomic operation on dq
auto traits =
get_ck_fmha_bwd_traits(mask, q_dtype_str, head_size_8x, is_dropout, alibi_slopes_.has_value(), deterministic);
get_ck_fmha_bwd_traits(mask,
q_dtype_str,
head_size_8x,
is_dropout,
attn_bias_.has_value(),
deterministic,
bias_requires_grad);
auto args =
get_ck_fmha_bwd_args(

@@ -368,7 +397,9 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x
q,
k,
v,
alibi_slopes_,
attn_bias_,
bias_requires_grad,
grad_bias,
out,
softmax_lse,
dout_padded,

@@ -400,6 +431,14 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x
dv = dv.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)});
}
return { dq, dk, dv, softmax_d };
at::Tensor dbias;
if(bias_requires_grad) {
dbias = grad_bias.value();
} else {
dbias = at::empty({batch_size, num_heads, seqlen_q, seqlen_k}, q.options());
}
return { dq, dk, dv, softmax_d, dbias };
}
} // namespace pytorch_flash

@@ -7,7 +7,6 @@
#include <mask.hpp>
namespace pytorch_flash {

@@ -16,7 +15,7 @@ fmha_fwd_traits get_ck_fmha_fwd_traits(const mask_info &mask,
int head_size,
bool has_dropout,
bool has_lse,
bool enable_alibi)
bool enable_bias)
{
return fmha_fwd_traits{head_size,
head_size,

@@ -24,7 +23,7 @@ fmha_fwd_traits get_ck_fmha_fwd_traits(const mask_info &mask,
false, // is_group_mode
true, // is_v_rowmajor
mask.type,
enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
enable_bias ? bias_enum::elementwise_bias : bias_enum::no_bias,
has_lse,
has_dropout,
false}; // do_fp8_static_quant

@@ -44,7 +43,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
const at::Tensor q,
const at::Tensor k,
const at::Tensor v,
std::optional<at::Tensor> &alibi_slopes_,
std::optional<at::Tensor> &attn_bias_,
at::Tensor out,
at::Tensor softmax_lse,
at::Tensor dropout_randval,

@@ -57,7 +56,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
// v: (batch_size, seqlen_k, nheads_k, d)
// o: (batch_size, seqlen_q, nheads, d)
// alibi_slopes:(batch_size, nheads) or (nhead)
// attn_bias: (batch_size, nheads, seqlen_q, seqlen_k)
// lse: (batch_size, nheads, seqlen_q)
// randval: (batch_size, nheads, seqlen_q, seqlen_k)

@@ -82,56 +81,58 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
ck_tile::index_t batch_stride_lse = has_lse ? softmax_lse.stride(0) : 0;
ck_tile::index_t batch_stride_randval = has_dropout_randval ? dropout_randval.stride(0) : 0;
void *alibi_slopes_ptr = nullptr;
ck_tile::index_t stride_alibi_slopes = 0;
if (alibi_slopes_.has_value()) {
auto alibi_slopes = alibi_slopes_.value();
CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
TORCH_CHECK(alibi_slopes.sizes() == at::IntArrayRef({h}) || alibi_slopes.sizes() == at::IntArrayRef({b, h}));
alibi_slopes_ptr = alibi_slopes.data_ptr();
stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
void *attn_bias_ptr = nullptr;
ck_tile::index_t stride_attn_bias = 0;
ck_tile::index_t batch_stride_bias = 0;
ck_tile::index_t nhead_stride_bias = 0;
if (attn_bias_.has_value()) {
auto a_b = attn_bias_.value();
CHECK_DEVICE(a_b);
TORCH_CHECK(a_b.stride(-1) == 1, "attention bias tensor must have contiguous last dimension");
attn_bias_ptr = a_b.data_ptr();
stride_attn_bias = a_b.stride(2);
nhead_stride_bias = a_b.stride(1);
batch_stride_bias = a_b.stride(0);
}
return fmha_fwd_args{q.data_ptr(),
k.data_ptr(),
v.data_ptr(),
alibi_slopes_ptr, // bias
attn_bias_ptr, // bias
has_dropout_randval ? dropout_randval.data_ptr() : nullptr,
has_lse ? softmax_lse.data_ptr() : nullptr,
out.data_ptr(),
nullptr, // seqstart_q
nullptr, // seqstart_k
nullptr, // seqstart_q
nullptr, // seqstart_k
nullptr,
seqlen_q,
seqlen_k,
b,
seqlen_q, // max_seqlen_q
d, // hdim_q
d, // hdim_v
h, // nhead
h_k, // nhead_k
softmax_scale, // scale_s
1, // scale_p
1, // scale_o
seqlen_q, // max_seqlen_q
d, // hdim_q
d, // hdim_v
h, // nhead
h_k, // nhead_k
softmax_scale, // scale_s
1, // scale_p
1, // scale_o
stride_q,
stride_k,
stride_v,
stride_alibi_slopes,
stride_attn_bias,
stride_randval,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
0, // nhead_stride_bias, FA without bias
nhead_stride_bias, // nhead_stride_bias
nhead_stride_randval,
nhead_stride_lse,
nhead_stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
0, // batch_stride_bias, FA without bias
batch_stride_bias, // batch_stride_bias
batch_stride_randval,
batch_stride_lse,
batch_stride_o,

@@ -148,14 +149,14 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
std::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads xhead_size
std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout,
const float softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
const bool return_dropout_randval,
std::optional<at::Generator> gen_)
std::optional<at::Generator> gen_,
const std::optional<at::Tensor>& attn_bias_) // batch_size x nheads x seqlen_q x seqlen_k
{
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,

@@ -189,7 +190,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
if (window_size_right >= seqlen_k) { window_size_right = -1; }
// causal=true is the same as causal=false in this case
if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
if (seqlen_q == 1 && !attn_bias_.has_value()) { is_causal = false; }
mask_info mask;
if (is_causal) {

@@ -209,7 +210,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
// H/t Daniel Haziza
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value();
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !attn_bias_.has_value();
const int ngroups = num_heads / num_heads_k;
at::Tensor temp_q = q;
if (seqlenq_ngroups_swapped) {

@@ -305,6 +306,12 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
}
std::optional<at::Tensor> attn_bias;
if( attn_bias_.has_value())
{
attn_bias = attn_bias_;
}
if (seqlen_k > 0) {
auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
auto stream = at::cuda::getCurrentHIPStream().stream();

@@ -317,7 +324,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
head_size_8x,
has_dropout,
has_lse,
alibi_slopes_.has_value());
attn_bias_.has_value());
auto args =
get_ck_fmha_fwd_args(

@@ -333,7 +340,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
q,
k,
v,
alibi_slopes_,
attn_bias,
out,
softmax_lse,
p,

@@ -14,16 +14,17 @@ fmha_bwd_traits get_ck_fmha_varlen_bwd_traits(const mask_info &mask,
std::string dtype,
int head_size,
bool has_dropout,
bool enable_alibi,
bool deterministic)
bool enable_bias,
bool deterministic,
bool bias_requires_grad)
{
return fmha_bwd_traits{head_size,
head_size,
dtype,
true, // is_group_mode
mask.type,
enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
false, // has_dbias
enable_bias ? bias_enum::elementwise_bias : bias_enum::no_bias,
bias_requires_grad, // has_dbias
has_dropout,
false, // s_randval
deterministic};

@@ -43,7 +44,9 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
const at::Tensor v,
const at::Tensor seqlens_q,
const at::Tensor seqlens_k,
std::optional<at::Tensor> &alibi_slopes_,
std::optional<at::Tensor> &attn_bias_,
bool bias_requires_grad,
std::optional<at::Tensor> &grad_bias,
const at::Tensor out,
const at::Tensor softmax_lse,
const at::Tensor dout,

@@ -115,23 +118,40 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
float p_undrop = 1.0 - p_dropout;
void *alibi_slopes_ptr = nullptr;
ck_tile::index_t stride_alibi_slopes = 0;
// bias: (batch_size, nheads, seqlen_q, seqlen_k)
void *attn_bias_ptr = nullptr;
ck_tile::index_t nhead_stride_bias = 0;
ck_tile::index_t batch_stride_bias = 0;
ck_tile::index_t stride_attn_bias = 0;
if (attn_bias_.has_value()) {
auto a_b = attn_bias_.value();
CHECK_DEVICE(a_b);
TORCH_CHECK(a_b.stride(-1) == 1, "Attention bias tensor must have contiguous last dimension");
attn_bias_ptr = a_b.data_ptr();
stride_attn_bias = a_b.stride(2);
nhead_stride_bias = a_b.stride(1);
batch_stride_bias = a_b.stride(0);
}
if (alibi_slopes_.has_value()) {
auto alibi_slopes = alibi_slopes_.value();
CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
TORCH_CHECK(alibi_slopes.sizes() == at::IntArrayRef({h}) || alibi_slopes.sizes() == at::IntArrayRef({b, h}));
alibi_slopes_ptr = alibi_slopes.data_ptr();
// alibi_slopes:(batch_size, nheads) or (nhead)
stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
void *dbias_ptr = nullptr;
ck_tile::index_t stride_dbias = 0;
ck_tile::index_t nhead_stride_dbias = 0;
ck_tile::index_t batch_stride_dbias = 0;
// dbias: (batch_size, nheads, seqlen_q, seqlen_k)
if(bias_requires_grad) {
// If bias_requires_grad is true, grad_bias is guaranteed to have a value via line 270
//grad_bias
auto dbias = grad_bias.value();
dbias_ptr = dbias.data_ptr();
stride_dbias = dbias.stride(2);
nhead_stride_dbias = dbias.stride(1);
batch_stride_dbias = dbias.stride(0);
}
return fmha_bwd_args{q.data_ptr(),
k.data_ptr(),
v.data_ptr(),
alibi_slopes_ptr, // bias
attn_bias_ptr, // bias
out.data_ptr(),
softmax_lse.data_ptr(),
dout.data_ptr(),

@@ -140,7 +160,7 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
dq.data_ptr(),
dk.data_ptr(),
dv.data_ptr(),
nullptr, // dbias
dbias_ptr, // dbias
dq_acc.data_ptr(), // dq_acc
seqlens_q.data_ptr(), // seqstart_q
seqlens_k.data_ptr(), // seqstart_k

@@ -158,7 +178,7 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
stride_q,
stride_k,
stride_v,
stride_alibi_slopes,
stride_attn_bias,
stride_o,
0, // stride_randval
stride_do,

@@ -166,11 +186,11 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
stride_dq,
stride_dk,
stride_dv,
0, // stride_dbias, FA without bias
stride_dbias, // stride_dbias
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
0, // nhead_stride_bias, FA without bias
nhead_stride_bias, // nhead_stride_bias
nhead_stride_o,
0, // nhead_stride_randval
nhead_stride_do,

@@ -179,11 +199,11 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
nhead_stride_dq,
nhead_stride_dk,
nhead_stride_dv,
0, // nhead_stride_dbias, FA without dbias
nhead_stride_dbias, // nhead_stride_dbias
batch_stride_q,
batch_stride_k,
batch_stride_v,
0 , // batch_stride_bias, FA without bias
batch_stride_bias, // batch_stride_bias
batch_stride_o,
0, // batch_stride_randval
batch_stride_do,

@@ -192,7 +212,7 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
batch_stride_dq,
batch_stride_dk,
batch_stride_dv,
0 , // batch_stride_dbias, FA without dbias
batch_stride_dbias, // batch_stride_dbias
split_stride_dq_acc,
mask.left,
mask.right,

@@ -202,7 +222,7 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
drop_seed_offset};
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
mha_varlen_bwd_ck(const at::Tensor &dout, // total_q x num_heads x head_size
const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i

@@ -214,7 +234,9 @@ mha_varlen_bwd_ck(const at::Tensor &dout, // total_q x num_hea
std::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
std::optional<at::Tensor> &attn_bias_, // b x num_heads x seqlen_q x seqlen_k
bool bias_requires_grad,
std::optional<at::Tensor> &grad_bias,
const int max_seqlen_q,
const int max_seqlen_k, // max sequence length to choose the kernel
const float p_dropout, // probability to drop

@@ -260,6 +282,9 @@ mha_varlen_bwd_ck(const at::Tensor &dout, // total_q x num_hea
CHECK_CONTIGUOUS(cu_seqlens_q);
CHECK_CONTIGUOUS(cu_seqlens_k);
TORCH_CHECK((bias_requires_grad && grad_bias.has_value()) || (!bias_requires_grad),
"If bias_requires_grad is set, grad_bias must have a value");
const auto sizes = q.sizes();
const int total_q = sizes[0];

@@ -381,7 +406,13 @@ mha_varlen_bwd_ck(const at::Tensor &dout, // total_q x num_hea
ck_tile::stream_config stream_config{stream};
dq.zero_(); // ck use atomic operation on dq
auto traits =
get_ck_fmha_varlen_bwd_traits(mask, q_dtype_str, head_size_8x, is_dropout, alibi_slopes_.has_value(), deterministic);
get_ck_fmha_varlen_bwd_traits(mask,
q_dtype_str,
head_size_8x,
is_dropout,
attn_bias_.has_value(),
deterministic,
bias_requires_grad);
auto args =
get_ck_fmha_varlen_bwd_args(

@@ -397,7 +428,9 @@ mha_varlen_bwd_ck(const at::Tensor &dout, // total_q x num_hea
v,
cu_seqlens_q,
cu_seqlens_k,
alibi_slopes_,
attn_bias_,
bias_requires_grad,
grad_bias,
out,
softmax_lse,
dout_padded,

@@ -428,7 +461,14 @@ mha_varlen_bwd_ck(const at::Tensor &dout, // total_q x num_hea
dk = dk.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)});
dv = dv.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)});
}
at::Tensor dbias;
if(bias_requires_grad) {
dbias = grad_bias.value();
} else {
dbias = at::empty({batch_size, num_heads, max_seqlen_q, max_seqlen_k}, q.options());
}
return { dq, dk, dv, softmax_d };
return { dq, dk, dv, softmax_d, dbias };
}
} // namespace pytorch_flash

@ -13,7 +13,7 @@ fmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask,
|
|||
int head_size,
|
||||
bool has_dropout,
|
||||
bool has_lse,
|
||||
bool enable_alibi)
|
||||
bool enable_bias)
|
||||
{
|
||||
return fmha_fwd_traits{head_size,
|
||||
head_size,
|
||||
|
|
@ -21,7 +21,7 @@ fmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask,
|
|||
true, // is_group_mode
|
||||
true, // is_v_rowmajor
|
||||
mask.type,
|
||||
enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
|
||||
enable_bias ? bias_enum::elementwise_bias : bias_enum::no_bias,
|
||||
has_lse,
|
||||
has_dropout,
|
||||
false}; // do_fp8_static_quant
|
||||
|
|
@ -42,11 +42,10 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
|
|||
const at::Tensor v,
|
||||
const at::Tensor seqlens_q,
|
||||
const at::Tensor seqlens_k,
|
||||
std::optional<at::Tensor> &alibi_slopes_,
|
||||
std::optional<at::Tensor> &attn_bias_,
|
||||
at::Tensor out,
|
||||
at::Tensor softmax_lse,
|
||||
at::Tensor dropout_randval,
|
||||
|
||||
float softmax_scale,
|
||||
float p_dropout,
|
||||
std::pair<uint64_t*, uint64_t*> drop_seed_offset)
|
||||
|
|
@ -56,7 +55,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
|
|||
// v: (total_k, nheads_k, d)
|
||||
// o: (total_q, nheads, d)
|
||||
|
||||
// alibi_slopes:(batch, nheads) or (nhead)
|
||||
// attn_bias :(batch, nheads, max_seqlen_q, max_seqlen_k)
|
||||
// lse: (batch, nheads, max_seqlen_q)
|
||||
// randval: (nheads, total_q, max_seqlen_k)
|
||||
|
||||
|
|
@@ -84,22 +83,23 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
     ck_tile::index_t batch_stride_lse = has_lse ? softmax_lse.stride(0) : 0;
     ck_tile::index_t batch_stride_randval = 0;
 
-    void *alibi_slopes_ptr = nullptr;
-    ck_tile::index_t stride_alibi_slopes = 0;
+    void *attn_bias_ptr = nullptr;
+    ck_tile::index_t stride_attn_bias = 0;
 
-    if (alibi_slopes_.has_value()) {
-        auto alibi_slopes = alibi_slopes_.value();
-        CHECK_DEVICE(alibi_slopes);
-        TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-        TORCH_CHECK(alibi_slopes.sizes() == at::IntArrayRef({h}) || alibi_slopes.sizes() == at::IntArrayRef({b, h}));
-        alibi_slopes_ptr = alibi_slopes.data_ptr();
-        stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
+    if (attn_bias_.has_value()) {
+        auto a_b = attn_bias_.value();
+        CHECK_DEVICE(a_b);
+        TORCH_CHECK(a_b.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
+        //TORCH_CHECK(alibi_slopes.sizes() == at::IntArrayRef({h}) || alibi_slopes.sizes() == at::IntArrayRef({b, h}));
+        attn_bias_ptr = a_b.data_ptr();
+        //stride_attn_bias = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
+        stride_attn_bias = a_b.stride(0);
     }
 
     return fmha_fwd_args{q.data_ptr(),
                          k.data_ptr(),
                          v.data_ptr(),
-                         alibi_slopes_ptr, // bias
+                         attn_bias_ptr, // bias
                          has_dropout_randval ? dropout_randval.data_ptr() : nullptr,
                          has_lse ? softmax_lse.data_ptr() : nullptr,
                          out.data_ptr(),
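Two leftovers from the copied ALiBi path are visible above: the `TORCH_CHECK` message still says "ALiBi slopes" even though it now guards the bias tensor, and the shape check is commented out, so only last-dim contiguity is enforced before `stride(0)` is taken as the batch stride. A stricter validation, following the `attn_bias: (batch, nheads, max_seqlen_q, max_seqlen_k)` layout documented earlier in this file, might look like this (a sketch, not part of the commit):

    #include <ATen/ATen.h>

    // Validate the documented 4-D bias layout before handing raw pointers to CK.
    void check_attn_bias(const at::Tensor& bias, int64_t b, int64_t h,
                         int64_t max_seqlen_q, int64_t max_seqlen_k) {
      TORCH_CHECK(bias.dim() == 4, "attn_bias must be 4-D");
      TORCH_CHECK(bias.stride(-1) == 1,
                  "attn_bias must have a contiguous last dimension");
      TORCH_CHECK(bias.sizes() == at::IntArrayRef({b, h, max_seqlen_q, max_seqlen_k}),
                  "attn_bias must be (batch, nheads, max_seqlen_q, max_seqlen_k)");
    }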
@@ -120,7 +120,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
                          stride_q,
                          stride_k,
                          stride_v,
-                         stride_alibi_slopes,
+                         stride_attn_bias,
                          stride_randval,
                          stride_o,
                          nhead_stride_q,
@@ -153,7 +153,6 @@ mha_varlen_fwd_ck(const at::Tensor &q, // total_q x num_heads
                   const at::Tensor &cu_seqlens_q, // b+1
                   const at::Tensor &cu_seqlens_k, // b+1
                   std::optional<at::Tensor> & /*seqused_k*/,
-                  std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
                   int max_seqlen_q,
                   const int max_seqlen_k,
                   const float p_dropout,
@@ -163,7 +162,8 @@ mha_varlen_fwd_ck(const at::Tensor &q, // total_q x num_heads
                   int window_size_left,
                   int window_size_right,
                   const bool return_dropout_randval,
-                  std::optional<at::Generator> gen_)
+                  std::optional<at::Generator> gen_,
+                  const std::optional<at::Tensor>& attn_bias_)
 {
     auto q_dtype = q.dtype();
     TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
@@ -200,7 +200,7 @@ mha_varlen_fwd_ck(const at::Tensor &q, // total_q x num_heads
     const int max_num_blocks_per_seq = 0;
     const int num_blocks = 0;
 
-    if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }  // causal=true is the same as causal=false in this case
+    if (max_seqlen_q == 1 && !attn_bias_.has_value()) { is_causal = false; }  // causal=true is the same as causal=false in this case
 
     // TODO
     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
@@ -307,6 +307,13 @@ mha_varlen_fwd_ck(const at::Tensor &q, // total_q x num_heads
             flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), philox_args, rng_state_ptr);
     }
 
+    // remove const from attn_bias_
+    std::optional<at::Tensor> attn_bias;
+    if( attn_bias_.has_value())
+    {
+        attn_bias = attn_bias_;
+    }
+
 
     if (max_seqlen_k > 0) {
         auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
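The `remove const` block above exists because `attn_bias_` arrives as a `const std::optional<at::Tensor>&` while `get_ck_fmha_varlen_fwd_args` takes a mutable `std::optional<at::Tensor>&`. The copy is cheap: copying an `at::Tensor` duplicates the reference-counted handle, not the storage. A small standalone demonstration (sketch):

    #include <ATen/ATen.h>
    #include <cassert>
    #include <optional>

    int main() {
      at::Tensor bias = at::zeros({2, 4, 8, 8});
      std::optional<at::Tensor> local = bias;        // copies the handle only
      assert(local->data_ptr() == bias.data_ptr());  // same underlying storage
      return 0;
    }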
@@ -314,7 +321,7 @@ mha_varlen_fwd_ck(const at::Tensor &q, // total_q x num_heads
         ck_tile::stream_config stream_config{stream};
 
         auto traits =
-            get_ck_fmha_varlen_fwd_traits(mask, q_dtype_str, head_size_8x, has_dropout, has_lse, alibi_slopes_.has_value());
+            get_ck_fmha_varlen_fwd_traits(mask, q_dtype_str, head_size_8x, has_dropout, has_lse, attn_bias_.has_value());
 
         auto args =
             get_ck_fmha_varlen_fwd_args(
@@ -331,7 +338,7 @@ mha_varlen_fwd_ck(const at::Tensor &q, // total_q x num_heads
                 v_padded,
                 cu_seqlens_q,
                 cu_seqlens_k,
-                alibi_slopes_,
+                attn_bias,
                 out,
                 softmax_lse,
                 p,
[File diff suppressed because it is too large]
@@ -1,11 +0,0 @@
-#!/bin/bash
-set -ex
-
-file_renaming_txt="rename_ck_autogen_files.output.txt"
-rm -rf $file_renaming_txt
-for file in `ls fmha_*wd*hip`; do
-  sha1=$(sha1sum $file | cut -d' ' -f1)
-  new_file="fmha_ck_autogen_${sha1}.hip"
-  mv $file $new_file
-  echo "$file -> $new_file" >> $file_renaming_txt
-done
@@ -143,15 +143,14 @@ mha_fwd_ck(
     const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size
     std::optional<at::Tensor>&
         out_, // batch_size x seqlen_q x num_heads x head_size
-    std::optional<at::Tensor>&
-        alibi_slopes_, // num_heads or batch_size x num_heads
     const float p_dropout,
     const float softmax_scale,
     bool is_causal,
     int window_size_left,
     int window_size_right,
     const bool return_softmax,
-    std::optional<at::Generator> gen_);
+    std::optional<at::Generator> gen_,
+    const std::optional<at::Tensor>& attn_bias_); // batch_size x nheads x seqlen_q x seqlen_k
 
 std::tuple<
     at::Tensor,
@@ -176,7 +175,6 @@ mha_varlen_fwd_ck(
     std::optional<at::Tensor>&
         seqused_k, // b. If given, only this many elements of each batch
                    // element's keys are used.
-    std::optional<at::Tensor>& alibi_slopes_, // num_heads or b x num_heads
     int max_seqlen_q,
     const int max_seqlen_k,
     const float p_dropout,
@@ -186,9 +184,10 @@ mha_varlen_fwd_ck(
     int window_size_left,
     int window_size_right,
     const bool return_softmax,
-    std::optional<at::Generator> gen_);
+    std::optional<at::Generator> gen_,
+    const std::optional<at::Tensor>& attn_bias_);
 
-std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd_ck(
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd_ck(
     const at::Tensor& dout, // batch_size x seqlen_q x num_heads, x head_size_og
     const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size
     const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size
@@ -202,7 +201,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd_ck(
     std::optional<at::Tensor>&
         dv_, // batch_size x seqlen_k x num_heads_k x head_size
     std::optional<at::Tensor>&
-        alibi_slopes_, // num_heads or batch_size x num_heads
+        attn_bias_, // batch_size x num_heads x seqlen_q x seqlen_k
+    bool bias_requires_grad,
+    std::optional<at::Tensor>& grad_bias,
     const float p_dropout, // probability to drop
     const float softmax_scale,
     const bool is_causal,
@@ -212,7 +213,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd_ck(
     const at::Tensor philox_seed,
     const at::Tensor philox_offset);
 
-std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd_ck(
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd_ck(
     const at::Tensor& dout, // total_q x num_heads, x head_size
     const at::Tensor&
         q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
@@ -230,7 +231,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd_ck(
         dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
     const at::Tensor& cu_seqlens_q, // b+1
     const at::Tensor& cu_seqlens_k, // b+1
-    std::optional<at::Tensor>& alibi_slopes_, // num_heads or b x num_heads
+    std::optional<at::Tensor>& attn_bias_, // num_heads or b x num_heads
+    bool bias_requires_grad,
+    std::optional<at::Tensor>& grad_bias,
     const int max_seqlen_q,
     const int max_seqlen_k, // max sequence length to choose the kernel
     const float p_dropout, // probability to drop
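Both CK backward entry points declared above now return five tensors, with the bias gradient last, and take `bias_requires_grad`/`grad_bias`; note the inline comment on the varlen `attn_bias_` parameter still carries the old `// num_heads or b x num_heads` alibi shape note. Call sites unpack the widened tuple with structured bindings; a self-contained sketch of the consumption pattern (`fake_bwd` is a stand-in, not a PyTorch function):

    #include <ATen/ATen.h>
    #include <tuple>

    std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> fake_bwd() {
      auto t = at::zeros({1});
      return {t, t, t, t, t};  // dq, dk, dv, softmax_d, dbias
    }

    at::Tensor consume() {
      auto [dq, dk, dv, softmax_d, dbias] = fake_bwd();
      // Flash-attention wrappers discard dbias; the mem-eff path returns it.
      return dbias;
    }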
@@ -273,19 +276,20 @@ mha_fwd(
 #if defined(USE_CK_FLASH_ATTENTION)
   if (at::globalContext().getROCmFAPreferredBackend() ==
       at::ROCmFABackend::Ck) {
+    std::optional<at::Tensor> dummy_attn_bias = std::nullopt;
     return mha_fwd_ck(
         q,
         k,
         v,
         out_,
-        alibi_slopes_,
         p_dropout,
         softmax_scale,
         is_causal,
         window_size_left,
         window_size_right,
         return_softmax,
-        gen_);
+        gen_,
+        dummy_attn_bias); // Not used in flash attention
   } else {
     return mha_fwd_aot(
         q,
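All four wrappers follow the same dispatch shape seen above: a compile-time `USE_CK_FLASH_ATTENTION` guard around a runtime backend check, with the AOT path as the fallback. Condensed into one predicate (a sketch of the pattern; the include is assumed, not taken from the diff):

    #include <ATen/Context.h>

    // True only when the build has CK kernels and the user selected them at runtime.
    bool use_ck_backend() {
    #if defined(USE_CK_FLASH_ATTENTION)
      return at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck;
    #else
      return false;
    #endif
    }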
@@ -358,6 +362,7 @@ mha_varlen_fwd(
 #if defined(USE_CK_FLASH_ATTENTION)
   if (at::globalContext().getROCmFAPreferredBackend() ==
       at::ROCmFABackend::Ck) {
+    std::optional<at::Tensor> dummy_attn_bias = std::nullopt;
     return mha_varlen_fwd_ck(
         q,
         k,
@@ -366,7 +371,6 @@ mha_varlen_fwd(
         cu_seqlens_q,
         cu_seqlens_k,
         seqused_k,
-        alibi_slopes_,
         max_seqlen_q,
         max_seqlen_k,
         p_dropout,
@@ -376,7 +380,8 @@ mha_varlen_fwd(
         window_size_left,
         window_size_right,
         return_softmax,
-        gen_);
+        gen_,
+        dummy_attn_bias); // Not used in flash attention
   } else {
     return mha_varlen_fwd_aot(
         q,
@@ -450,25 +455,34 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd(
 #if defined(USE_CK_FLASH_ATTENTION)
   if (at::globalContext().getROCmFAPreferredBackend() ==
       at::ROCmFABackend::Ck) {
-    return mha_bwd_ck(
-        dout,
-        q,
-        k,
-        v,
-        out,
-        softmax_lse,
-        dq_,
-        dk_,
-        dv_,
-        alibi_slopes_,
-        p_dropout,
-        softmax_scale,
-        is_causal,
-        window_size_left,
-        window_size_right,
-        deterministic,
-        philox_seed,
-        philox_offset);
+    std::optional<at::Tensor> non_null_dbias = std::nullopt;
+    auto[dQuery,
+         dKey,
+         dValue,
+         dSoftmax,
+         dBias] = mha_bwd_ck(
+        dout,
+        q,
+        k,
+        v,
+        out,
+        softmax_lse,
+        dq_,
+        dk_,
+        dv_,
+        alibi_slopes_,
+        false, // bias_requires_grad
+        non_null_dbias,
+        p_dropout,
+        softmax_scale,
+        is_causal,
+        window_size_left,
+        window_size_right,
+        deterministic,
+        philox_seed,
+        philox_offset);
+    // for FA return [dQ, dV, dK, dSoftmax]
+    return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue), std::move(dSoftmax));
   } else {
     return mha_bwd_aot(
         dout,
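One nit in the block above: the trailing comment reads `[dQ, dV, dK, dSoftmax]`, but the tuple is assembled as `(dQuery, dKey, dValue, dSoftmax)`, i.e. dK before dV; the `make_tuple` order is what callers see. The flash-attention wrappers keep their four-tuple contract by dropping `dBias`, as in this sketch (names illustrative):

    #include <ATen/ATen.h>
    #include <tuple>

    using Bwd5 = std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>;
    using Bwd4 = std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>;

    // Narrow the five-tuple (dq, dk, dv, softmax_d, dbias) to the FA contract.
    Bwd4 drop_dbias(Bwd5 five) {
      auto& [dq, dk, dv, softmax_d, dbias] = five;
      return {std::move(dq), std::move(dk), std::move(dv), std::move(softmax_d)};
    }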
@@ -551,30 +565,39 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd
 #if defined(USE_CK_FLASH_ATTENTION)
   if (at::globalContext().getROCmFAPreferredBackend() ==
       at::ROCmFABackend::Ck) {
-    return mha_varlen_bwd_ck(
-        dout,
-        q,
-        k,
-        v,
-        out,
-        softmax_lse,
-        dq_,
-        dk_,
-        dv_,
-        cu_seqlens_q,
-        cu_seqlens_k,
-        alibi_slopes_,
-        max_seqlen_q,
-        max_seqlen_k,
-        p_dropout,
-        softmax_scale,
-        zero_tensors,
-        is_causal,
-        window_size_left,
-        window_size_right,
-        deterministic,
-        philox_seed,
-        philox_offset);
+    std::optional<at::Tensor> non_null_dbias = std::nullopt;
+    auto[dQuery,
+         dKey,
+         dValue,
+         dSoftmax,
+         dBias] = mha_varlen_bwd_ck(
+        dout,
+        q,
+        k,
+        v,
+        out,
+        softmax_lse,
+        dq_,
+        dk_,
+        dv_,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        alibi_slopes_,
+        false, // bias_requires_grad
+        non_null_dbias,
+        max_seqlen_q,
+        max_seqlen_k,
+        p_dropout,
+        softmax_scale,
+        zero_tensors,
+        is_causal,
+        window_size_left,
+        window_size_right,
+        deterministic,
+        philox_seed,
+        philox_offset);
+    // for FA return [dQ, dV, dK, dSoftmax]
+    return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue), std::move(dSoftmax));
   } else {
     return mha_varlen_bwd_aot(
         dout,