[ROCm] CK Memory-Efficient Attention (attention bias support) (#147778)

Implements CK as the backend for memory-efficient attention, with a couple of caveats:

- Still enabled via `torch.backends.cuda.preferred_rocm_fa_library("ck")`
- Does NOT support Nested Tensors

Using the mem_eff path allows us to use an attention bias with the CK SDPA backend.
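
A minimal usage sketch (not part of the diff): it assumes the `torch.backends.cuda.preferred_rocm_fa_library` setter mentioned above and the standard `sdpa_kernel` context manager, with illustrative shapes and a dense float `attn_mask` acting as the attention bias:

```python
# Hypothetical usage sketch; tensor shapes are illustrative.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Prefer CK for ROCm flash/mem-efficient attention (setter referenced above).
torch.backends.cuda.preferred_rocm_fa_library("ck")

q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
# A dense (float) attn_mask is treated as an additive attention bias; setting
# requires_grad exercises the bias-gradient (dBias) path wired up here.
bias = torch.randn(2, 8, 128, 128, device="cuda", dtype=torch.float16,
                   requires_grad=True)

# Pin SDPA to the memory-efficient backend so the CK mem_eff path is used.
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)

out.sum().backward()  # bias.grad now holds the attention-bias gradient
```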

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147778
Approved by: https://github.com/houseroad
Andy Lugo authored 2025-03-11 19:02:56 +00:00; committed by PyTorch MergeBot
parent a1cb67b69e
commit 4d10da731b
13 changed files with 848 additions and 2186 deletions


@ -88,6 +88,7 @@
#include <ATen/native/transformers/hip/aotriton_adapter.h>
#include <aotriton/flash.h>
#include <aotriton/runtime.h>
#include <ATen/native/transformers/hip/flash_attn/ck/me_ck_api.h>
#endif
#endif
@ -1243,104 +1244,146 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
#ifdef USE_ROCM
// ROCM Implementation
auto ret = aotriton::v2::flash::check_gpu(stream);
if (hipSuccess != ret) {
TORCH_CHECK(false,
"[AOTriton] Accelerated SDPA only supports MI200/MI300X/7900XTX/9070XT GPUs"
" (gfx90a/gfx942/gfx1100/gfx1201)")
}
// AOTriton may accept aligned on logsumexp tensor in the future for better
// performance, but for now it requires compact logsumexp tensor, even if
// compute_logsumexp is false
constexpr int kAlignLSE = 1;
// Need this in both aot and CK case
const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float();
res = at::empty({B, M, num_heads, Kv}, query.options());
at::Tensor softmax_lse;
logsumexp = at::empty(
if(at::globalContext().getROCmFAPreferredBackend() ==
at::ROCmFABackend::Ck) {
#if defined(USE_CK_FLASH_ATTENTION)
std::optional<Tensor> out(res);
std::optional<Tensor> seqused_k = std::nullopt;
std::optional<Tensor> alibi_slopes = std::nullopt;
auto
[out_,
q,
k,
v,
lse,
seed_t,
offset_t,
p] =
pytorch_flash::mem_eff_forward_ck(
query,
key,
value,
dropout_p,
false, // return dropout_randval
custom_mask_type == 0 ? false : true, // is_causal
softmax_scale,
bias,
out,
std::nullopt, // cu_seqlens_q
std::nullopt, // cu_seqlens_k
seqstart_q,
seqstart_k,
std::nullopt, // gen_
seqused_k); // seqused_k_
logsumexp = lse;
#else
TORCH_CHECK(false, "Attempting to use CK mem_eff_forward backend in a build that has not built CK");
#endif
} else { // use aotriton
auto ret = aotriton::v2::flash::check_gpu(stream);
if (hipSuccess != ret) {
TORCH_CHECK(false,
"[AOTriton] Accelerated SDPA only supports MI200/MI300X/Navi31 GPUs"
" (gfx90a:sramecc+:xnack-/gfx942:sramecc+:xnack-/gfx1100)")
}
// AOTriton may accept aligned on logsumexp tensor in the future for better
// performance, but for now it requires compact logsumexp tensor, even if
// compute_logsumexp is false
constexpr int kAlignLSE = 1;
res = at::empty({B, M, num_heads, Kv}, query.options());
at::Tensor softmax_lse;
logsumexp = at::empty(
{ B, num_heads, compute_logsumexp ? max_seqlen_q : 0},
query.options().dtype(at::ScalarType::Float));
if (compute_logsumexp) {
softmax_lse = logsumexp.view({B * num_heads, max_seqlen_q});
}
at::Tensor q_t = query.transpose(1, 2);
at::Tensor k_t = key.transpose(1, 2);
at::Tensor v_t = value.transpose(1, 2);
at::Tensor output_t = res.transpose(1, 2);
bool is_causal;
if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
is_causal = true;
} else if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
is_causal = false;
} else {
TORCH_CHECK(false, "[_efficient_attention_forward] Unsupported mask type on ROCM, for now");
}
if (compute_logsumexp) {
softmax_lse = logsumexp.view({B * num_heads, max_seqlen_q});
}
at::Tensor q_t = query.transpose(1, 2);
at::Tensor k_t = key.transpose(1, 2);
at::Tensor v_t = value.transpose(1, 2);
at::Tensor output_t = res.transpose(1, 2);
bool is_causal;
if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
is_causal = true;
} else if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
is_causal = false;
} else {
TORCH_CHECK(false, "[_efficient_attention_forward] Unsupported mask type on ROCM, for now");
}
const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float();
at::Tensor atomic_counter;
if (is_causal) {
atomic_counter = at::zeros({1}, query.options().dtype(at::kInt));
}
at::Tensor atomic_counter;
if (is_causal) {
atomic_counter = at::zeros({1}, query.options().dtype(at::kInt));
}
using aotriton::v2::flash::attn_fwd;
using aotriton::v2::flash::attn_fwd_compact_varlen;
using sdp::aotriton_adapter::mk_aotensor;
using sdp::aotriton_adapter::mk_aoscalartensor;
using sdp::aotriton_adapter::mk_philoxtensor;
using sdp::aotriton_adapter::mk_atomictensor;
aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, aotriton::DType::kFloat16);
aotriton::TensorView<2> empty_t2(0, {0, 0}, {0, 0}, aotriton::DType::kFloat32);
at::Tensor softmax_fa_t = at::empty({ 0, 0, 0, 0 }, query.options());
const bool use_philox_state = in_capture_stream;
auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t);
auto offset1 = use_philox_state ? mk_philoxtensor(philox_state.offset_.ptr) : mk_aoscalartensor(offset_t);
auto offset2 = use_philox_state ? philox_state.offset_intragraph_ : 0;
auto seed_output = mk_philoxtensor(use_philox_state ? seed_t.data_ptr<int64_t>() : nullptr);
auto offset_output = mk_philoxtensor(use_philox_state ? offset_t.data_ptr<int64_t>() : nullptr);
auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr<int32_t>() : nullptr);
hipError_t err; // TODO: Error handling
if (seqstart_q.has_value()) {
// varlen aka nested tensor
err = attn_fwd_compact_varlen(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4,
mk_aotensor<1>(seqstart_q.value(), "cu_seqlens_q"),
mk_aotensor<1>(seqstart_k.value(), "cu_seqlens_k"),
max_seqlen_q,
max_seqlen_k,
softmax_scale,
compute_logsumexp ? mk_aotensor<2>(softmax_lse, "M") : empty_t2,
mk_aotensor(output_t, "Out"),
dropout_p,
seed,
offset1,
offset2,
seed_output,
offset_output,
mk_aotensor(softmax_fa_t, "encoded_softmax"),
is_causal,
persistent_counter,
stream);
} else {
err = attn_fwd(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4,
softmax_scale,
compute_logsumexp ? mk_aotensor<2>(softmax_lse, "M") : empty_t2,
mk_aotensor(output_t, "Out"),
dropout_p,
seed,
offset1,
offset2,
seed_output,
offset_output,
mk_aotensor(softmax_fa_t, "encoded_softmax"),
is_causal,
persistent_counter,
stream);
}
using aotriton::v2::flash::attn_fwd;
using aotriton::v2::flash::attn_fwd_compact_varlen;
using sdp::aotriton_adapter::mk_aotensor;
using sdp::aotriton_adapter::mk_aoscalartensor;
using sdp::aotriton_adapter::mk_philoxtensor;
using sdp::aotriton_adapter::mk_atomictensor;
aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, aotriton::DType::kFloat16);
aotriton::TensorView<2> empty_t2(0, {0, 0}, {0, 0}, aotriton::DType::kFloat32);
at::Tensor softmax_fa_t = at::empty({ 0, 0, 0, 0 }, query.options());
const bool use_philox_state = in_capture_stream;
auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t);
auto offset1 = use_philox_state ? mk_philoxtensor(philox_state.offset_.ptr) : mk_aoscalartensor(offset_t);
auto offset2 = use_philox_state ? philox_state.offset_intragraph_ : 0;
auto seed_output = mk_philoxtensor(use_philox_state ? seed_t.data_ptr<int64_t>() : nullptr);
auto offset_output = mk_philoxtensor(use_philox_state ? offset_t.data_ptr<int64_t>() : nullptr);
auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr<int32_t>() : nullptr);
hipError_t err; // TODO: Error handling
if (seqstart_q.has_value()) {
// varlen aka nested tensor
err = attn_fwd_compact_varlen(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4,
mk_aotensor<1>(seqstart_q.value(), "cu_seqlens_q"),
mk_aotensor<1>(seqstart_k.value(), "cu_seqlens_k"),
max_seqlen_q,
max_seqlen_k,
softmax_scale,
compute_logsumexp ? mk_aotensor<2>(softmax_lse, "M") : empty_t2,
mk_aotensor(output_t, "Out"),
dropout_p,
seed,
offset1,
offset2,
seed_output,
offset_output,
mk_aotensor(softmax_fa_t, "encoded_softmax"),
is_causal,
persistent_counter,
stream);
} else {
err = attn_fwd(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4,
softmax_scale,
compute_logsumexp ? mk_aotensor<2>(softmax_lse, "M") : empty_t2,
mk_aotensor(output_t, "Out"),
dropout_p,
seed,
offset1,
offset2,
seed_output,
offset_output,
mk_aotensor(softmax_fa_t, "encoded_softmax"),
is_causal,
persistent_counter,
stream);
}
} // CK BACKEND
#else
// CUDA Implementation
cudaDeviceProp* p = at::cuda::getDeviceProperties(query.device().index());


@ -47,6 +47,7 @@
#include <ATen/native/transformers/hip/aotriton_adapter.h>
#include <aotriton/flash.h>
#include <aotriton/runtime.h>
#include <ATen/native/transformers/hip/flash_attn/ck/me_ck_api.h>
#endif
#endif
@ -412,94 +413,133 @@ _efficient_attention_backward(
#ifdef USE_ROCM
// ROCM Implementation
TORCH_CHECK(!num_splits_key.has_value(),
if(at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck)
{
#if defined(USE_CK_FLASH_ATTENTION)
const auto my_softmax_scale = sdp::calculate_scale(query, scale).expect_float();
// Store grad_bias in optional
std::optional<at::Tensor> opt_grad_bias = grad_bias;
auto
[dQ,
dK,
dV,
dBias] =
pytorch_flash::mem_eff_backward_ck(
grad_out,
query,
key,
value,
out,
logsumexp,
grad_q,
grad_k,
grad_v,
bias,
bias_requires_grad,
opt_grad_bias,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
float(dropout_p),
my_softmax_scale,
custom_mask_type == 0 ? false : true, // is_causal
false, // deterministic
false, // zero_tensors
philox_seed,
philox_offset);
grad_bias = dBias;
#else
TORCH_CHECK(false, "Attempting to use CK mem_eff_backward backend in a build that has not built CK");
#endif
} else {
TORCH_CHECK(!num_splits_key.has_value(),
"ROCM does not support num_split_keys in _efficient_attention_forward");
TORCH_CHECK(!window_size.has_value(),
TORCH_CHECK(!window_size.has_value(),
"ROCM does not support window_size in _efficient_attention_forward");
auto ret = aotriton::v2::flash::check_gpu(stream);
if (hipSuccess != ret) {
TORCH_CHECK(false,
auto ret = aotriton::v2::flash::check_gpu(stream);
if (hipSuccess != ret) {
TORCH_CHECK(false,
"[AOTriton] Accelerated SDPA only supports MI200/MI300X/7900XTX/9070XT GPUs"
" (gfx90a/gfx942/gfx1100/gfx1201)")
}
const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float();
bool is_causal;
if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
is_causal = true;
} else if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
is_causal = false;
} else {
TORCH_CHECK(false, "[_efficient_attention_backward] Unsupported mask type in AOTriton, for now");
}
at::Tensor q_t = query.permute({0,2,1,3});
at::Tensor k_t = key.permute({0,2,1,3});
at::Tensor v_t = value.permute({0,2,1,3});
at::Tensor out_t = out.permute({0,2,1,3});
at::Tensor dq_t = grad_q.permute({0,2,1,3});
at::Tensor dk_t = grad_k.permute({0,2,1,3});
at::Tensor dv_t = grad_v.permute({0,2,1,3});
at::Tensor dout_t = grad_out.permute({0,2,1,3});
at::Tensor softmax_lse = logsumexp.view({B * nH, max_seqlen_q});
hipError_t err;
using aotriton::v2::flash::attn_bwd;
using aotriton::v2::flash::attn_bwd_fused;
using aotriton::v2::flash::attn_bwd_compact_varlen;
using sdp::aotriton_adapter::mk_aotensor;
using sdp::aotriton_adapter::mk_aoscalartensor;
using sdp::aotriton_adapter::cast_dtype;
aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, cast_dtype(query.dtype()));
if (cu_seqlens_q.has_value()) {
at::Tensor delta = at::empty_like(softmax_lse).contiguous();
// varlen aka Nested tensor
err = attn_bwd_compact_varlen(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
mk_aotensor<1>(cu_seqlens_q.value(), "cu_seqlens_q"),
mk_aotensor<1>(cu_seqlens_k.value(), "cu_seqlens_k"),
max_seqlen_q,
max_seqlen_k,
bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4,
softmax_scale,
mk_aotensor(out_t, "out"),
mk_aotensor(dout_t, "dout"),
mk_aotensor(dq_t, "dq"),
mk_aotensor(dk_t, "dk"),
mk_aotensor(dv_t, "dv"),
bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4,
mk_aotensor<2>(softmax_lse, "L"),
mk_aotensor<2>(delta, "delta"),
float(dropout_p),
mk_aoscalartensor(philox_seed),
mk_aoscalartensor(philox_offset),
0,
is_causal,
stream);
} else {
auto d_head = Kv;
bool use_fused_bwd = d_head <= 192 && d_head * max_seqlen_q < 64 * 512;
if (use_fused_bwd) {
err = attn_bwd_fused(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4,
softmax_scale,
mk_aotensor(out_t, "out"),
mk_aotensor(dout_t, "dout"),
mk_aotensor(dq_t, "dq"),
mk_aotensor(dk_t, "dk"),
mk_aotensor(dv_t, "dv"),
bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4,
mk_aotensor<2>(softmax_lse, "L"),
float(dropout_p),
mk_aoscalartensor(philox_seed),
mk_aoscalartensor(philox_offset),
0,
is_causal,
stream);
}
const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float();
bool is_causal;
if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
is_causal = true;
} else if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
is_causal = false;
} else {
TORCH_CHECK(false, "[_efficient_attention_backward] Unsupported mask type in AOTriton, for now");
}
at::Tensor q_t = query.permute({0,2,1,3});
at::Tensor k_t = key.permute({0,2,1,3});
at::Tensor v_t = value.permute({0,2,1,3});
at::Tensor out_t = out.permute({0,2,1,3});
at::Tensor dq_t = grad_q.permute({0,2,1,3});
at::Tensor dk_t = grad_k.permute({0,2,1,3});
at::Tensor dv_t = grad_v.permute({0,2,1,3});
at::Tensor dout_t = grad_out.permute({0,2,1,3});
at::Tensor softmax_lse = logsumexp.view({B * nH, max_seqlen_q});
hipError_t err;
using aotriton::v2::flash::attn_bwd;
using aotriton::v2::flash::attn_bwd_fused;
using aotriton::v2::flash::attn_bwd_compact_varlen;
using sdp::aotriton_adapter::mk_aotensor;
using sdp::aotriton_adapter::mk_aoscalartensor;
using sdp::aotriton_adapter::cast_dtype;
aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, cast_dtype(query.dtype()));
if (cu_seqlens_q.has_value()) {
at::Tensor delta = at::empty_like(softmax_lse).contiguous();
err = attn_bwd(mk_aotensor(q_t, "q"),
// varlen aka Nested tensor
err = attn_bwd_compact_varlen(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
mk_aotensor<1>(cu_seqlens_q.value(), "cu_seqlens_q"),
mk_aotensor<1>(cu_seqlens_k.value(), "cu_seqlens_k"),
max_seqlen_q,
max_seqlen_k,
bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4,
softmax_scale,
mk_aotensor(out_t, "out"),
mk_aotensor(dout_t, "dout"),
mk_aotensor(dq_t, "dq"),
mk_aotensor(dk_t, "dk"),
mk_aotensor(dv_t, "dv"),
bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4,
mk_aotensor<2>(softmax_lse, "L"),
mk_aotensor<2>(delta, "delta"),
float(dropout_p),
mk_aoscalartensor(philox_seed),
mk_aoscalartensor(philox_offset),
0,
is_causal,
stream);
} else { // cu_seqlens.has_value
auto d_head = Kv;
bool use_fused_bwd = d_head <= 192 && d_head * max_seqlen_q < 64 * 512;
if (use_fused_bwd) {
err = attn_bwd_fused(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4,
softmax_scale,
mk_aotensor(out_t, "out"),
mk_aotensor(dout_t, "dout"),
mk_aotensor(dq_t, "dq"),
mk_aotensor(dk_t, "dk"),
mk_aotensor(dv_t, "dv"),
bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4,
mk_aotensor<2>(softmax_lse, "L"),
float(dropout_p),
mk_aoscalartensor(philox_seed),
mk_aoscalartensor(philox_offset),
0,
is_causal,
stream);
} else {
at::Tensor delta = at::empty_like(softmax_lse).contiguous();
err = attn_bwd(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4,
@ -518,9 +558,10 @@ _efficient_attention_backward(
0,
is_causal,
stream);
}
}
#else
} // use_fused_bwd
} // cu_seqlens.has_value()
} // Use CK
#else // USE_CUDA
at::Tensor workspace;
cudaDeviceProp* p = at::cuda::getDeviceProperties(query.device().index());
const int computeCapability = p->major * 10 + p->minor;


@ -787,4 +787,4 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
}
} // namespace pytorch_flash
#endif
#endif // USE_FLASH_ATTENTION


@ -0,0 +1,120 @@
#include <ATen/native/transformers/hip/flash_attn/flash_common_hip.hpp>
#include <ATen/native/transformers/hip/flash_attn/ck/me_ck_api.h>
#if defined(USE_CK_FLASH_ATTENTION)
namespace pytorch_flash {
std::tuple<
at::Tensor, // dQ
at::Tensor, // dK
at::Tensor, // dV
at::Tensor> // dBias
mem_eff_backward_ck(
const at::Tensor &dout,
const at::Tensor &q,
const at::Tensor &k,
const at::Tensor &v,
const at::Tensor &out,
const at::Tensor &softmax_lse,
const at::Tensor &dq_,
const at::Tensor &dk_,
const at::Tensor &dv_,
std::optional<at::Tensor> &attn_bias,
bool bias_requires_grad,
std::optional<at::Tensor> &grad_bias,
std::optional<at::Tensor> &cu_seqlens_q,
std::optional<at::Tensor> &cu_seqlens_k,
int max_seqlen_q,
int max_seqlen_k,
float p_dropout,
float scale,
bool is_causal,
bool deterministic,
bool zero_tensors,
at::Tensor philox_seed,
at::Tensor philox_offset)
{
const int non_null_window_left = -1;
const int non_null_window_right = -1;
std::optional<at::Tensor> opt_dQ, opt_dK, opt_dV;
opt_dQ = dq_;
opt_dK = dk_;
opt_dV = dv_;
if(!cu_seqlens_q.has_value()) {
auto
[dQ,
dK,
dV,
softmax_d,
dBias] =
mha_bwd_ck(
dout,
q,
k,
v,
out,
softmax_lse,
opt_dQ,
opt_dK,
opt_dV,
attn_bias,
bias_requires_grad,
grad_bias,
p_dropout,
scale,
is_causal,
non_null_window_left,
non_null_window_right,
deterministic,
philox_seed,
philox_offset);
return std::make_tuple(std::move(dQ), std::move(dK), std::move(dV), std::move(dBias));
} else {
// cu_seqlens only has a value in the nested tensor path which CK does not support
TORCH_CHECK(false, "Nested Tensors not supported with CK backend.");
return std::make_tuple(at::Tensor{}, at::Tensor{}, at::Tensor{}, at::Tensor{});
// TODO: Fix nested tensor(varlen) path
/*
auto
[dQ,
dK,
dV,
softmax_d,
dBias] =
mha_varlen_bwd_ck(
dout,
q,
k,
v,
out,
softmax_lse,
opt_dQ,
opt_dK,
opt_dV,
cu_seqlens_q.value(),
cu_seqlens_k.value(),
attn_bias,
bias_requires_grad,
grad_bias,
max_seqlen_q,
max_seqlen_k,
p_dropout,
scale,
zero_tensors,
is_causal,
non_null_window_left,
non_null_window_right,
deterministic,
philox_seed,
philox_offset);
return std::make_tuple(std::move(dQ), std::move(dK), std::move(dV), std::move(dBias));
*/
}
return std::make_tuple(at::Tensor{}, at::Tensor{}, at::Tensor{}, at::Tensor{});
}
} // namespace pytorch_flash
#endif // USE_CK_FLASH_ATTENTION


@ -0,0 +1,67 @@
#pragma once
#include <cstddef>
#include <ATen/core/Tensor.h>
#if defined(USE_CK_FLASH_ATTENTION)
namespace pytorch_flash {
std::tuple<
at::Tensor, // output
at::Tensor, // q
at::Tensor, // k
at::Tensor, // v
at::Tensor, // lse
at::Tensor, // seed
at::Tensor, // offset
at::Tensor> // dropout randval
mem_eff_forward_ck(
const at::Tensor& q,
const at::Tensor& k,
const at::Tensor& v,
float p_dropout,
bool return_dropout_randval,
std::optional<bool> is_causal,
std::optional<float> scale,
const std::optional<at::Tensor>& attn_bias_,
std::optional<at::Tensor>& out_,
const std::optional<at::Tensor>& cu_seqlens_q,
const std::optional<at::Tensor>& cu_seqlens_k,
const std::optional<at::Tensor>& seqstart_q,
const std::optional<at::Tensor>& seqstart_k,
std::optional<at::Generator> gen_,
std::optional<at::Tensor>& seqused_k_
);
std::tuple<
at::Tensor, // dQ
at::Tensor, // dK
at::Tensor, // dV
at::Tensor> // dBias
mem_eff_backward_ck(
const at::Tensor &dout,
const at::Tensor &q,
const at::Tensor &k,
const at::Tensor &v,
const at::Tensor &out,
const at::Tensor &softmax_lse,
const at::Tensor &dq_,
const at::Tensor &dk_,
const at::Tensor &dv_,
std::optional<at::Tensor> &attn_bias,
bool bias_requires_grad,
std::optional<at::Tensor> &grad_bias,
std::optional<at::Tensor> &cu_seqlens_q,
std::optional<at::Tensor> &cu_seqlens_k,
int max_seqlen_q,
int max_seqlen_k,
float p_dropout,
float scale,
bool is_causal,
bool deterministic,
bool zero_tensors,
const at::Tensor philox_seed,
const at::Tensor philox_offset);
} // namespace pytorch_flash
#endif // USE_CK_FLASH_ATTENTION


@ -0,0 +1,96 @@
#include <ATen/native/transformers/hip/flash_attn/flash_common_hip.hpp>
#include <ATen/native/transformers/hip/flash_attn/ck/me_ck_api.h>
#if defined(USE_CK_FLASH_ATTENTION)
namespace pytorch_flash {
std::tuple<
at::Tensor, // output
at::Tensor, // q
at::Tensor, // k
at::Tensor, // v
at::Tensor, // lse
at::Tensor, // seed
at::Tensor, // offset
at::Tensor> // dropout randval
mem_eff_forward_ck(
const at::Tensor& q,
const at::Tensor& k,
const at::Tensor& v,
float p_dropout,
bool return_dropout_randval,
std::optional<bool> is_causal,
std::optional<float> scale,
const std::optional<at::Tensor>& attn_bias_,
std::optional<at::Tensor>& out_,
const std::optional<at::Tensor>& cu_seqlens_q,
const std::optional<at::Tensor>& cu_seqlens_k,
const std::optional<at::Tensor>& seqstart_q,
const std::optional<at::Tensor>& seqstart_k,
std::optional<at::Generator> gen_,
std::optional<at::Tensor>& seqused_k_) {
const int non_null_window_left = -1;
const int non_null_window_right = -1;
TORCH_CHECK(
cu_seqlens_q.has_value() == cu_seqlens_k.has_value(),
"cu_seqlens_q and cu_seqlens_k must be both set or both not set");
if(!seqstart_q.has_value()){
return mha_fwd_ck(
q, // q
k, // k
v, // v
out_, // opt(out_)
p_dropout, // p_dropout
scale.value(), // opt(softmax_scale)
is_causal.value(), // opt(is_causal)
non_null_window_left, // window_size_left
non_null_window_right, // window_size_right
false, // return_softmax/return_debug_mask
gen_, // gen
attn_bias_); // attn_bias
} else {
// seqstart_q is only set in nested tensor path which CK does not support
TORCH_CHECK(false, "Nested Tensors not supported with CK backend.");
return std::make_tuple(at::Tensor{},
at::Tensor{},
at::Tensor{},
at::Tensor{},
at::Tensor{},
at::Tensor{},
at::Tensor{},
at::Tensor{});
// TODO: Fix nested tensor(varlen) path
/*
// max sequence lengths are now at T.size(1) since q,k,v were all transposed
// in _scaled_dot_product_efficient_attention_cuda
const int64_t max_seqlen_q = q.size(1);
const int64_t max_seqlen_k = k.size(1);
return mha_varlen_fwd_ck(
q, // q
k, // k
v, // v
out_, // opt(out)
seqstart_q.value(), // cu_seqlens_q
seqstart_k.value(), // cu_seqlens_k
seqused_k_, // opt(seqused_k)
max_seqlen_q, // max_seqlen_q
max_seqlen_k, // max_seqlen_k
p_dropout, // p_dropout
scale.value(), // softmax_scale
false, // zero_tensors
is_causal.value(), // is_causal
non_null_window_left, // window_size_left
non_null_window_right, // window_size_right
false, // return_softmax/return_debug_mask
gen_, // gen
attn_bias_); // attn_bias
*/
}
}
} // namespace pytorch_flash
#endif // USE_CK_FLASH_ATTENTION


@ -12,16 +12,17 @@ fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask,
std::string dtype,
int head_size,
bool has_dropout,
bool enable_alibi,
bool deterministic)
bool enable_bias,
bool deterministic,
bool bias_requires_grad)
{
return fmha_bwd_traits{head_size,
head_size,
dtype,
false, // is_group_mode
mask.type,
enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
false, // has_dbias
enable_bias ? bias_enum::elementwise_bias : bias_enum::no_bias,
bias_requires_grad, // has_dbias
has_dropout,
false, // s_randval
deterministic};
@ -39,7 +40,9 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
const at::Tensor q,
const at::Tensor k,
const at::Tensor v,
std::optional<at::Tensor> &alibi_slopes_,
std::optional<at::Tensor> &attn_bias_,
bool bias_requires_grad,
std::optional<at::Tensor> &grad_bias,
const at::Tensor out,
const at::Tensor softmax_lse,
const at::Tensor dout,
@ -82,7 +85,6 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
ck_tile::index_t nhead_stride_do = dout.stride(2);
// d: (batch_size, nheads, seqlen_q)
// CK assume d share the same stride with lse
// dq: (batch_size, seqlen_q, nheads, hdim)
ck_tile::index_t batch_stride_dq = dq.stride(0);
@ -105,86 +107,103 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
ck_tile::index_t stride_dq_acc = dq_acc.stride(2);
ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(3);
float p_undrop = 1.0 - p_dropout;
void *alibi_slopes_ptr = nullptr;
ck_tile::index_t stride_alibi_slopes = 0;
if (alibi_slopes_.has_value()) {
auto alibi_slopes = alibi_slopes_.value();
CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
TORCH_CHECK(alibi_slopes.sizes() == at::IntArrayRef({h}) || alibi_slopes.sizes() == at::IntArrayRef({b, h}));
alibi_slopes_ptr = alibi_slopes.data_ptr();
// alibi_slopes:(batch_size, nheads) or (nhead)
stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
// bias: (batch_size, nheads, seqlen_q, seqlen_k)
void *attn_bias_ptr = nullptr;
ck_tile::index_t nhead_stride_bias = 0;
ck_tile::index_t batch_stride_bias = 0;
ck_tile::index_t stride_attn_bias = 0;
if (attn_bias_.has_value()) {
auto a_b = attn_bias_.value();
CHECK_DEVICE(a_b);
TORCH_CHECK(a_b.stride(-1) == 1, "Attention bias tensor must have contiguous last dimension");
attn_bias_ptr = a_b.data_ptr();
stride_attn_bias = a_b.stride(2);
nhead_stride_bias = a_b.stride(1);
batch_stride_bias = a_b.stride(0);
}
// dbias: (batch_size, nheads, seqlen_q, seqlen_k)
void *dbias_ptr = nullptr;
ck_tile::index_t stride_dbias = 0;
ck_tile::index_t nhead_stride_dbias = 0;
ck_tile::index_t batch_stride_dbias = 0;
if(bias_requires_grad) {
// If bias_requires_grad is true, grad_bias is guaranteed to have a value via line 270
//grad_bias
auto dbias = grad_bias.value();
dbias_ptr = dbias.data_ptr();
stride_dbias = dbias.stride(2);
nhead_stride_dbias = dbias.stride(1);
batch_stride_dbias = dbias.stride(0);
}
float p_undrop = 1.0 - p_dropout;
return fmha_bwd_args{q.data_ptr(),
k.data_ptr(),
v.data_ptr(),
alibi_slopes_ptr, // bias
attn_bias_ptr, // bias
out.data_ptr(),
softmax_lse.data_ptr(),
dout.data_ptr(),
d.data_ptr(),
nullptr, // rand_val
nullptr, // rand_val
dq.data_ptr(),
dk.data_ptr(),
dv.data_ptr(),
nullptr, // dbias
dq_acc.data_ptr(), // dq_acc
nullptr, // seqstart_q
nullptr, // seqstart_k
nullptr, // seqlen_k_ptr
dbias_ptr, // dbias
dq_acc.data_ptr(), // dq_acc
nullptr, // seqstart_q
nullptr, // seqstart_k
nullptr, // seqlen_k_ptr
seqlen_q,
seqlen_k,
b,
seqlen_q, // max_seqlen_q
seqlen_k, // max_seqlen_k
hdim, // hdim_q
hdim, // hdim_v
h, // nhead
h_k, // nhead_k
seqlen_q, // max_seqlen_q
seqlen_k, // max_seqlen_k
hdim, // hdim_q
hdim, // hdim_v
h, // nhead
h_k, // nhead_k
softmax_scale,
stride_q,
stride_k,
stride_v,
stride_alibi_slopes,
stride_attn_bias,
stride_o,
0, // stride_randval
0, // stride_randval
stride_do,
stride_dq_acc,
stride_dq,
stride_dk,
stride_dv,
0, // stride_dbias, FA without bias
stride_dbias, // stride_dbias
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
0, // nhead_stride_bias, FA without bias
nhead_stride_bias, // nhead_stride_bias
nhead_stride_o,
0, // nhead_stride_randval
0, // nhead_stride_randval
nhead_stride_do,
nhead_stride_lse,
nhead_stride_dq_acc,
nhead_stride_dq,
nhead_stride_dk,
nhead_stride_dv,
0, // nhead_stride_dbias, FA without dbias
nhead_stride_dbias, // nhead_stride_dbias
batch_stride_q,
batch_stride_k,
batch_stride_v,
0 , // batch_stride_bias, FA without bias
batch_stride_bias, // batch_stride_bias
batch_stride_o,
0, // batch_stride_randval
0, // batch_stride_randval
batch_stride_do,
batch_stride_lse,
batch_stride_dq_acc,
batch_stride_dq,
batch_stride_dk,
batch_stride_dv,
0 , // batch_stride_dbias, FA without dbias
batch_stride_dbias, // batch_stride_dbias
split_stride_dq_acc,
mask.left,
mask.right,
@ -193,8 +212,7 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
p_undrop,
drop_seed_offset};
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og
const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
@ -204,7 +222,9 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x
std::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
std::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
std::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
std::optional<at::Tensor> &attn_bias_, // batch_size x num_heads x seqlen_q x seqlen_k
bool bias_requires_grad,
std::optional<at::Tensor> &grad_bias,
const float p_dropout, // probability to drop
const float softmax_scale,
const bool is_causal,
@ -242,6 +262,9 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x
TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
TORCH_CHECK((bias_requires_grad && grad_bias.has_value()) || (!bias_requires_grad),
"If bias_requires_grad is set, grad_bias must have a value");
const auto sizes = q.sizes();
const int batch_size = sizes[0];
@ -354,7 +377,13 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x
ck_tile::stream_config stream_config{stream};
dq.zero_(); // ck use atomic operation on dq
auto traits =
get_ck_fmha_bwd_traits(mask, q_dtype_str, head_size_8x, is_dropout, alibi_slopes_.has_value(), deterministic);
get_ck_fmha_bwd_traits(mask,
q_dtype_str,
head_size_8x,
is_dropout,
attn_bias_.has_value(),
deterministic,
bias_requires_grad);
auto args =
get_ck_fmha_bwd_args(
@ -368,7 +397,9 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x
q,
k,
v,
alibi_slopes_,
attn_bias_,
bias_requires_grad,
grad_bias,
out,
softmax_lse,
dout_padded,
@ -400,6 +431,14 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x
dv = dv.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)});
}
return { dq, dk, dv, softmax_d };
at::Tensor dbias;
if(bias_requires_grad) {
dbias = grad_bias.value();
} else {
dbias = at::empty({batch_size, num_heads, seqlen_q, seqlen_k}, q.options());
}
return { dq, dk, dv, softmax_d, dbias };
}
} // namespace pytorch_flash


@ -7,7 +7,6 @@
#include <mask.hpp>
namespace pytorch_flash {
@ -16,7 +15,7 @@ fmha_fwd_traits get_ck_fmha_fwd_traits(const mask_info &mask,
int head_size,
bool has_dropout,
bool has_lse,
bool enable_alibi)
bool enable_bias)
{
return fmha_fwd_traits{head_size,
head_size,
@ -24,7 +23,7 @@ fmha_fwd_traits get_ck_fmha_fwd_traits(const mask_info &mask,
false, // is_group_mode
true, // is_v_rowmajor
mask.type,
enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
enable_bias ? bias_enum::elementwise_bias : bias_enum::no_bias,
has_lse,
has_dropout,
false}; // do_fp8_static_quant
@ -44,7 +43,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
const at::Tensor q,
const at::Tensor k,
const at::Tensor v,
std::optional<at::Tensor> &alibi_slopes_,
std::optional<at::Tensor> &attn_bias_,
at::Tensor out,
at::Tensor softmax_lse,
at::Tensor dropout_randval,
@ -57,7 +56,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
// v: (batch_size, seqlen_k, nheads_k, d)
// o: (batch_size, seqlen_q, nheads, d)
// alibi_slopes:(batch_size, nheads) or (nhead)
// attn_bias: (batch_size, nheads, seqlen_q, seqlen_k)
// lse: (batch_size, nheads, seqlen_q)
// randval: (batch_size, nheads, seqlen_q, seqlen_k)
@ -82,56 +81,58 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
ck_tile::index_t batch_stride_lse = has_lse ? softmax_lse.stride(0) : 0;
ck_tile::index_t batch_stride_randval = has_dropout_randval ? dropout_randval.stride(0) : 0;
void *alibi_slopes_ptr = nullptr;
ck_tile::index_t stride_alibi_slopes = 0;
if (alibi_slopes_.has_value()) {
auto alibi_slopes = alibi_slopes_.value();
CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
TORCH_CHECK(alibi_slopes.sizes() == at::IntArrayRef({h}) || alibi_slopes.sizes() == at::IntArrayRef({b, h}));
alibi_slopes_ptr = alibi_slopes.data_ptr();
stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
void *attn_bias_ptr = nullptr;
ck_tile::index_t stride_attn_bias = 0;
ck_tile::index_t batch_stride_bias = 0;
ck_tile::index_t nhead_stride_bias = 0;
if (attn_bias_.has_value()) {
auto a_b = attn_bias_.value();
CHECK_DEVICE(a_b);
TORCH_CHECK(a_b.stride(-1) == 1, "attention bias tensor must have contiguous last dimension");
attn_bias_ptr = a_b.data_ptr();
stride_attn_bias = a_b.stride(2);
nhead_stride_bias = a_b.stride(1);
batch_stride_bias = a_b.stride(0);
}
return fmha_fwd_args{q.data_ptr(),
k.data_ptr(),
v.data_ptr(),
alibi_slopes_ptr, // bias
attn_bias_ptr, // bias
has_dropout_randval ? dropout_randval.data_ptr() : nullptr,
has_lse ? softmax_lse.data_ptr() : nullptr,
out.data_ptr(),
nullptr, // seqstart_q
nullptr, // seqstart_k
nullptr, // seqstart_q
nullptr, // seqstart_k
nullptr,
seqlen_q,
seqlen_k,
b,
seqlen_q, // max_seqlen_q
d, // hdim_q
d, // hdim_v
h, // nhead
h_k, // nhead_k
softmax_scale, // scale_s
1, // scale_p
1, // scale_o
seqlen_q, // max_seqlen_q
d, // hdim_q
d, // hdim_v
h, // nhead
h_k, // nhead_k
softmax_scale, // scale_s
1, // scale_p
1, // scale_o
stride_q,
stride_k,
stride_v,
stride_alibi_slopes,
stride_attn_bias,
stride_randval,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
0, // nhead_stride_bias, FA without bias
nhead_stride_bias, // nhead_stride_bias
nhead_stride_randval,
nhead_stride_lse,
nhead_stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
0, // batch_stride_bias, FA without bias
batch_stride_bias, // batch_stride_bias
batch_stride_randval,
batch_stride_lse,
batch_stride_o,
@ -148,14 +149,14 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
std::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads xhead_size
std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout,
const float softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
const bool return_dropout_randval,
std::optional<at::Generator> gen_)
std::optional<at::Generator> gen_,
const std::optional<at::Tensor>& attn_bias_) // batch_size x nheads x seqlen_q x seqlen_k
{
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
@ -189,7 +190,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
if (window_size_right >= seqlen_k) { window_size_right = -1; }
// causal=true is the same as causal=false in this case
if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
if (seqlen_q == 1 && !attn_bias_.has_value()) { is_causal = false; }
mask_info mask;
if (is_causal) {
@ -209,7 +210,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
// H/t Daniel Haziza
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value();
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !attn_bias_.has_value();
const int ngroups = num_heads / num_heads_k;
at::Tensor temp_q = q;
if (seqlenq_ngroups_swapped) {
@ -305,6 +306,12 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
}
std::optional<at::Tensor> attn_bias;
if( attn_bias_.has_value())
{
attn_bias = attn_bias_;
}
if (seqlen_k > 0) {
auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
auto stream = at::cuda::getCurrentHIPStream().stream();
@ -317,7 +324,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
head_size_8x,
has_dropout,
has_lse,
alibi_slopes_.has_value());
attn_bias_.has_value());
auto args =
get_ck_fmha_fwd_args(
@ -333,7 +340,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
q,
k,
v,
alibi_slopes_,
attn_bias,
out,
softmax_lse,
p,


@ -14,16 +14,17 @@ fmha_bwd_traits get_ck_fmha_varlen_bwd_traits(const mask_info &mask,
std::string dtype,
int head_size,
bool has_dropout,
bool enable_alibi,
bool deterministic)
bool enable_bias,
bool deterministic,
bool bias_requires_grad)
{
return fmha_bwd_traits{head_size,
head_size,
dtype,
true, // is_group_mode
mask.type,
enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
false, // has_dbias
enable_bias ? bias_enum::elementwise_bias : bias_enum::no_bias,
bias_requires_grad, // has_dbias
has_dropout,
false, // s_randval
deterministic};
@ -43,7 +44,9 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
const at::Tensor v,
const at::Tensor seqlens_q,
const at::Tensor seqlens_k,
std::optional<at::Tensor> &alibi_slopes_,
std::optional<at::Tensor> &attn_bias_,
bool bias_requires_grad,
std::optional<at::Tensor> &grad_bias,
const at::Tensor out,
const at::Tensor softmax_lse,
const at::Tensor dout,
@ -115,23 +118,40 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
float p_undrop = 1.0 - p_dropout;
void *alibi_slopes_ptr = nullptr;
ck_tile::index_t stride_alibi_slopes = 0;
// bias: (batch_size, nheads, seqlen_q, seqlen_k)
void *attn_bias_ptr = nullptr;
ck_tile::index_t nhead_stride_bias = 0;
ck_tile::index_t batch_stride_bias = 0;
ck_tile::index_t stride_attn_bias = 0;
if (attn_bias_.has_value()) {
auto a_b = attn_bias_.value();
CHECK_DEVICE(a_b);
TORCH_CHECK(a_b.stride(-1) == 1, "Attention bias tensor must have contiguous last dimension");
attn_bias_ptr = a_b.data_ptr();
stride_attn_bias = a_b.stride(2);
nhead_stride_bias = a_b.stride(1);
batch_stride_bias = a_b.stride(0);
}
if (alibi_slopes_.has_value()) {
auto alibi_slopes = alibi_slopes_.value();
CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
TORCH_CHECK(alibi_slopes.sizes() == at::IntArrayRef({h}) || alibi_slopes.sizes() == at::IntArrayRef({b, h}));
alibi_slopes_ptr = alibi_slopes.data_ptr();
// alibi_slopes:(batch_size, nheads) or (nhead)
stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
void *dbias_ptr = nullptr;
ck_tile::index_t stride_dbias = 0;
ck_tile::index_t nhead_stride_dbias = 0;
ck_tile::index_t batch_stride_dbias = 0;
// dbias: (batch_size, nheads, seqlen_q, seqlen_k)
if(bias_requires_grad) {
// If bias_requires_grad is true, grad_bias is guaranteed to have a value via line 270
//grad_bias
auto dbias = grad_bias.value();
dbias_ptr = dbias.data_ptr();
stride_dbias = dbias.stride(2);
nhead_stride_dbias = dbias.stride(1);
batch_stride_dbias = dbias.stride(0);
}
return fmha_bwd_args{q.data_ptr(),
k.data_ptr(),
v.data_ptr(),
alibi_slopes_ptr, // bias
attn_bias_ptr, // bias
out.data_ptr(),
softmax_lse.data_ptr(),
dout.data_ptr(),
@ -140,7 +160,7 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
dq.data_ptr(),
dk.data_ptr(),
dv.data_ptr(),
nullptr, // dbias
dbias_ptr, // dbias
dq_acc.data_ptr(), // dq_acc
seqlens_q.data_ptr(), // seqstart_q
seqlens_k.data_ptr(), // seqstart_k
@ -158,7 +178,7 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
stride_q,
stride_k,
stride_v,
stride_alibi_slopes,
stride_attn_bias,
stride_o,
0, // stride_randval
stride_do,
@ -166,11 +186,11 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
stride_dq,
stride_dk,
stride_dv,
0, // stride_dbias, FA without bias
stride_dbias, // stride_dbias
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
0, // nhead_stride_bias, FA without bias
nhead_stride_bias, // nhead_stride_bias
nhead_stride_o,
0, // nhead_stride_randval
nhead_stride_do,
@ -179,11 +199,11 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
nhead_stride_dq,
nhead_stride_dk,
nhead_stride_dv,
0, // nhead_stride_dbias, FA without dbias
nhead_stride_dbias, // nhead_stride_dbias
batch_stride_q,
batch_stride_k,
batch_stride_v,
0 , // batch_stride_bias, FA without bias
batch_stride_bias, // batch_stride_bias
batch_stride_o,
0, // batch_stride_randval
batch_stride_do,
@ -192,7 +212,7 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
batch_stride_dq,
batch_stride_dk,
batch_stride_dv,
0 , // batch_stride_dbias, FA without dbias
batch_stride_dbias, // batch_stride_dbias
split_stride_dq_acc,
mask.left,
mask.right,
@ -202,7 +222,7 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
drop_seed_offset};
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
mha_varlen_bwd_ck(const at::Tensor &dout, // total_q x num_heads x head_size
const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
@ -214,7 +234,9 @@ mha_varlen_bwd_ck(const at::Tensor &dout, // total_q x num_hea
std::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
std::optional<at::Tensor> &attn_bias_, // b x num_heads x seqlen_q x seqlen_k
bool bias_requires_grad,
std::optional<at::Tensor> &grad_bias,
const int max_seqlen_q,
const int max_seqlen_k, // max sequence length to choose the kernel
const float p_dropout, // probability to drop
@ -260,6 +282,9 @@ mha_varlen_bwd_ck(const at::Tensor &dout, // total_q x num_hea
CHECK_CONTIGUOUS(cu_seqlens_q);
CHECK_CONTIGUOUS(cu_seqlens_k);
TORCH_CHECK((bias_requires_grad && grad_bias.has_value()) || (!bias_requires_grad),
"If bias_requires_grad is set, grad_bias must have a value");
const auto sizes = q.sizes();
const int total_q = sizes[0];
@ -381,7 +406,13 @@ mha_varlen_bwd_ck(const at::Tensor &dout, // total_q x num_hea
ck_tile::stream_config stream_config{stream};
dq.zero_(); // ck use atomic operation on dq
auto traits =
get_ck_fmha_varlen_bwd_traits(mask, q_dtype_str, head_size_8x, is_dropout, alibi_slopes_.has_value(), deterministic);
get_ck_fmha_varlen_bwd_traits(mask,
q_dtype_str,
head_size_8x,
is_dropout,
attn_bias_.has_value(),
deterministic,
bias_requires_grad);
auto args =
get_ck_fmha_varlen_bwd_args(
@ -397,7 +428,9 @@ mha_varlen_bwd_ck(const at::Tensor &dout, // total_q x num_hea
v,
cu_seqlens_q,
cu_seqlens_k,
alibi_slopes_,
attn_bias_,
bias_requires_grad,
grad_bias,
out,
softmax_lse,
dout_padded,
@ -428,7 +461,14 @@ mha_varlen_bwd_ck(const at::Tensor &dout, // total_q x num_hea
dk = dk.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)});
dv = dv.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)});
}
at::Tensor dbias;
if(bias_requires_grad) {
dbias = grad_bias.value();
} else {
dbias = at::empty({batch_size, num_heads, max_seqlen_q, max_seqlen_k}, q.options());
}
return { dq, dk, dv, softmax_d };
return { dq, dk, dv, softmax_d, dbias };
}
} // namespace pytorch_flash


@ -13,7 +13,7 @@ fmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask,
int head_size,
bool has_dropout,
bool has_lse,
bool enable_alibi)
bool enable_bias)
{
return fmha_fwd_traits{head_size,
head_size,
@ -21,7 +21,7 @@ fmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask,
true, // is_group_mode
true, // is_v_rowmajor
mask.type,
enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
enable_bias ? bias_enum::elementwise_bias : bias_enum::no_bias,
has_lse,
has_dropout,
false}; // do_fp8_static_quant
@ -42,11 +42,10 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
const at::Tensor v,
const at::Tensor seqlens_q,
const at::Tensor seqlens_k,
std::optional<at::Tensor> &alibi_slopes_,
std::optional<at::Tensor> &attn_bias_,
at::Tensor out,
at::Tensor softmax_lse,
at::Tensor dropout_randval,
float softmax_scale,
float p_dropout,
std::pair<uint64_t*, uint64_t*> drop_seed_offset)
@ -56,7 +55,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
// v: (total_k, nheads_k, d)
// o: (total_q, nheads, d)
// alibi_slopes:(batch, nheads) or (nhead)
// attn_bias :(batch, nheads, max_seqlen_q, max_seqlen_k)
// lse: (batch, nheads, max_seqlen_q)
// randval: (nheads, total_q, max_seqlen_k)
@ -84,22 +83,23 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
ck_tile::index_t batch_stride_lse = has_lse ? softmax_lse.stride(0) : 0;
ck_tile::index_t batch_stride_randval = 0;
void *alibi_slopes_ptr = nullptr;
ck_tile::index_t stride_alibi_slopes = 0;
void *attn_bias_ptr = nullptr;
ck_tile::index_t stride_attn_bias = 0;
if (alibi_slopes_.has_value()) {
auto alibi_slopes = alibi_slopes_.value();
CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
TORCH_CHECK(alibi_slopes.sizes() == at::IntArrayRef({h}) || alibi_slopes.sizes() == at::IntArrayRef({b, h}));
alibi_slopes_ptr = alibi_slopes.data_ptr();
stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
if (attn_bias_.has_value()) {
auto a_b = attn_bias_.value();
CHECK_DEVICE(a_b);
TORCH_CHECK(a_b.stride(-1) == 1, "Attention bias tensor must have contiguous last dimension");
attn_bias_ptr = a_b.data_ptr();
stride_attn_bias = a_b.stride(0);
}
return fmha_fwd_args{q.data_ptr(),
k.data_ptr(),
v.data_ptr(),
alibi_slopes_ptr, // bias
attn_bias_ptr, // bias
has_dropout_randval ? dropout_randval.data_ptr() : nullptr,
has_lse ? softmax_lse.data_ptr() : nullptr,
out.data_ptr(),
@ -120,7 +120,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
stride_q,
stride_k,
stride_v,
stride_alibi_slopes,
stride_attn_bias,
stride_randval,
stride_o,
nhead_stride_q,
@ -153,7 +153,6 @@ mha_varlen_fwd_ck(const at::Tensor &q, // total_q x num_heads
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
std::optional<at::Tensor> & /*seqused_k*/,
std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
int max_seqlen_q,
const int max_seqlen_k,
const float p_dropout,
@ -163,7 +162,8 @@ mha_varlen_fwd_ck(const at::Tensor &q, // total_q x num_heads
int window_size_left,
int window_size_right,
const bool return_dropout_randval,
std::optional<at::Generator> gen_)
std::optional<at::Generator> gen_,
const std::optional<at::Tensor>& attn_bias_)
{
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
@ -200,7 +200,7 @@ mha_varlen_fwd_ck(const at::Tensor &q, // total_q x num_heads
const int max_num_blocks_per_seq = 0;
const int num_blocks = 0;
if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case
if (max_seqlen_q == 1 && !attn_bias_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case
// TODO
// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
@ -307,6 +307,13 @@ mha_varlen_fwd_ck(const at::Tensor &q, // total_q x num_heads
flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), philox_args, rng_state_ptr);
}
// remove const from attn_bias_
std::optional<at::Tensor> attn_bias;
if( attn_bias_.has_value())
{
attn_bias = attn_bias_;
}
if (max_seqlen_k > 0) {
auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
@ -314,7 +321,7 @@ mha_varlen_fwd_ck(const at::Tensor &q, // total_q x num_heads
ck_tile::stream_config stream_config{stream};
auto traits =
get_ck_fmha_varlen_fwd_traits(mask, q_dtype_str, head_size_8x, has_dropout, has_lse, alibi_slopes_.has_value());
get_ck_fmha_varlen_fwd_traits(mask, q_dtype_str, head_size_8x, has_dropout, has_lse, attn_bias_.has_value());
auto args =
get_ck_fmha_varlen_fwd_args(
@ -331,7 +338,7 @@ mha_varlen_fwd_ck(const at::Tensor &q, // total_q x num_heads
v_padded,
cu_seqlens_q,
cu_seqlens_k,
alibi_slopes_,
attn_bias,
out,
softmax_lse,
p,


@ -1,11 +0,0 @@
#!/bin/bash
set -ex
file_renaming_txt="rename_ck_autogen_files.output.txt"
rm -rf $file_renaming_txt
for file in `ls fmha_*wd*hip`; do
sha1=$(sha1sum $file | cut -d' ' -f1)
new_file="fmha_ck_autogen_${sha1}.hip"
mv $file $new_file
echo "$file -> $new_file" >> $file_renaming_txt
done


@ -143,15 +143,14 @@ mha_fwd_ck(
const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size
std::optional<at::Tensor>&
out_, // batch_size x seqlen_q x num_heads x head_size
std::optional<at::Tensor>&
alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout,
const float softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
const bool return_softmax,
std::optional<at::Generator> gen_);
std::optional<at::Generator> gen_,
const std::optional<at::Tensor>& attn_bias_); // batch_size x nheads x seqlen_q x seqlen_k
std::tuple<
at::Tensor,
@ -176,7 +175,6 @@ mha_varlen_fwd_ck(
std::optional<at::Tensor>&
seqused_k, // b. If given, only this many elements of each batch
// element's keys are used.
std::optional<at::Tensor>& alibi_slopes_, // num_heads or b x num_heads
int max_seqlen_q,
const int max_seqlen_k,
const float p_dropout,
@ -186,9 +184,10 @@ mha_varlen_fwd_ck(
int window_size_left,
int window_size_right,
const bool return_softmax,
std::optional<at::Generator> gen_);
std::optional<at::Generator> gen_,
const std::optional<at::Tensor>& attn_bias_);
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd_ck(
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd_ck(
const at::Tensor& dout, // batch_size x seqlen_q x num_heads, x head_size_og
const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size
@ -202,7 +201,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd_ck(
std::optional<at::Tensor>&
dv_, // batch_size x seqlen_k x num_heads_k x head_size
std::optional<at::Tensor>&
alibi_slopes_, // num_heads or batch_size x num_heads
attn_bias_, // batch_size x num_heads x seqlen_q x seqlen_k
bool bias_requires_grad,
std::optional<at::Tensor>& grad_bias,
const float p_dropout, // probability to drop
const float softmax_scale,
const bool is_causal,
@ -212,7 +213,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd_ck(
const at::Tensor philox_seed,
const at::Tensor philox_offset);
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd_ck(
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd_ck(
const at::Tensor& dout, // total_q x num_heads, x head_size
const at::Tensor&
q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
@ -230,7 +231,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd_ck(
dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor& cu_seqlens_q, // b+1
const at::Tensor& cu_seqlens_k, // b+1
std::optional<at::Tensor>& alibi_slopes_, // num_heads or b x num_heads
std::optional<at::Tensor>& attn_bias_, // b x num_heads x seqlen_q x seqlen_k
bool bias_requires_grad,
std::optional<at::Tensor>& grad_bias,
const int max_seqlen_q,
const int max_seqlen_k, // max sequence length to choose the kernel
const float p_dropout, // probability to drop
@ -273,19 +276,20 @@ mha_fwd(
#if defined(USE_CK_FLASH_ATTENTION)
if (at::globalContext().getROCmFAPreferredBackend() ==
at::ROCmFABackend::Ck) {
std::optional<at::Tensor> dummy_attn_bias = std::nullopt;
return mha_fwd_ck(
q,
k,
v,
out_,
alibi_slopes_,
p_dropout,
softmax_scale,
is_causal,
window_size_left,
window_size_right,
return_softmax,
gen_);
gen_,
dummy_attn_bias); // Not used in flash attention
} else {
return mha_fwd_aot(
q,
@ -358,6 +362,7 @@ mha_varlen_fwd(
#if defined(USE_CK_FLASH_ATTENTION)
if (at::globalContext().getROCmFAPreferredBackend() ==
at::ROCmFABackend::Ck) {
std::optional<at::Tensor> dummy_attn_bias = std::nullopt;
return mha_varlen_fwd_ck(
q,
k,
@ -366,7 +371,6 @@ mha_varlen_fwd(
cu_seqlens_q,
cu_seqlens_k,
seqused_k,
alibi_slopes_,
max_seqlen_q,
max_seqlen_k,
p_dropout,
@ -376,7 +380,8 @@ mha_varlen_fwd(
window_size_left,
window_size_right,
return_softmax,
gen_);
gen_,
dummy_attn_bias); // Not used in flash attention
} else {
return mha_varlen_fwd_aot(
q,
@ -450,25 +455,34 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd(
#if defined(USE_CK_FLASH_ATTENTION)
if (at::globalContext().getROCmFAPreferredBackend() ==
at::ROCmFABackend::Ck) {
return mha_bwd_ck(
dout,
q,
k,
v,
out,
softmax_lse,
dq_,
dk_,
dv_,
alibi_slopes_,
p_dropout,
softmax_scale,
is_causal,
window_size_left,
window_size_right,
deterministic,
philox_seed,
philox_offset);
std::optional<at::Tensor> non_null_dbias = std::nullopt;
auto[dQuery,
dKey,
dValue,
dSoftmax,
dBias] = mha_bwd_ck(
dout,
q,
k,
v,
out,
softmax_lse,
dq_,
dk_,
dv_,
alibi_slopes_,
false, // bias_requires_grad
non_null_dbias,
p_dropout,
softmax_scale,
is_causal,
window_size_left,
window_size_right,
deterministic,
philox_seed,
philox_offset);
// for FA return [dQ, dK, dV, dSoftmax]
return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue), std::move(dSoftmax));
} else {
return mha_bwd_aot(
dout,
@ -551,30 +565,39 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd
#if defined(USE_CK_FLASH_ATTENTION)
if (at::globalContext().getROCmFAPreferredBackend() ==
at::ROCmFABackend::Ck) {
return mha_varlen_bwd_ck(
dout,
q,
k,
v,
out,
softmax_lse,
dq_,
dk_,
dv_,
cu_seqlens_q,
cu_seqlens_k,
alibi_slopes_,
max_seqlen_q,
max_seqlen_k,
p_dropout,
softmax_scale,
zero_tensors,
is_causal,
window_size_left,
window_size_right,
deterministic,
philox_seed,
philox_offset);
std::optional<at::Tensor> non_null_dbias = std::nullopt;
auto[dQuery,
dKey,
dValue,
dSoftmax,
dBias] = mha_varlen_bwd_ck(
dout,
q,
k,
v,
out,
softmax_lse,
dq_,
dk_,
dv_,
cu_seqlens_q,
cu_seqlens_k,
alibi_slopes_,
false, // bias_requires_grad
non_null_dbias,
max_seqlen_q,
max_seqlen_k,
p_dropout,
softmax_scale,
zero_tensors,
is_causal,
window_size_left,
window_size_right,
deterministic,
philox_seed,
philox_offset);
// for FA return [dQ, dK, dV, dSoftmax]
return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue), std::move(dSoftmax));
} else {
return mha_varlen_bwd_aot(
dout,