mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Revert "[ROCm] Bump AOTriton to 0.10b (#156290)"
This reverts commit 34d8e64ef6.
Reverted https://github.com/pytorch/pytorch/pull/156290 on behalf of https://github.com/atalman due to failing multiple internal tests ([comment](https://github.com/pytorch/pytorch/pull/156290#issuecomment-2992072727))
This commit is contained in:
parent
b4442f42a9
commit
1036f6d114
|
|
@ -1113,10 +1113,8 @@ _flash_attention_forward(
|
||||||
std::optional<Tensor> alibi_slopes = _alibi_slopes;
|
std::optional<Tensor> alibi_slopes = _alibi_slopes;
|
||||||
const float softcap = 0.0;
|
const float softcap = 0.0;
|
||||||
|
|
||||||
#ifndef USE_ROCM // ROCM backend accepts std::optional for window_size_left/right directly.
|
const int non_null_window_left = window_size_left.has_value() ? window_size_left.value() : -1;
|
||||||
const int non_null_window_left = window_size_left.value_or(-1);
|
const int non_null_window_right = window_size_right.has_value() ? window_size_right.value() : -1;
|
||||||
const int non_null_window_right = window_size_right.value_or(-1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// We are going to have two paths:
|
// We are going to have two paths:
|
||||||
// 1. The standard MHA path for dense tensors
|
// 1. The standard MHA path for dense tensors
|
||||||
|
|
@ -1153,13 +1151,8 @@ _flash_attention_forward(
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
false /*zero_tensors*/,
|
false /*zero_tensors*/,
|
||||||
is_causal,
|
is_causal,
|
||||||
#ifdef USE_ROCM
|
|
||||||
window_size_left,
|
|
||||||
window_size_right,
|
|
||||||
#else
|
|
||||||
non_null_window_left,
|
non_null_window_left,
|
||||||
non_null_window_right,
|
non_null_window_right,
|
||||||
#endif
|
|
||||||
softcap,
|
softcap,
|
||||||
return_debug_mask,
|
return_debug_mask,
|
||||||
std::nullopt /*gen_*/);
|
std::nullopt /*gen_*/);
|
||||||
|
|
@ -1182,13 +1175,8 @@ _flash_attention_forward(
|
||||||
dropout_p,
|
dropout_p,
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
is_causal,
|
is_causal,
|
||||||
#ifdef USE_ROCM
|
|
||||||
window_size_left,
|
|
||||||
window_size_right,
|
|
||||||
#else
|
|
||||||
non_null_window_left,
|
non_null_window_left,
|
||||||
non_null_window_right,
|
non_null_window_right,
|
||||||
#endif
|
|
||||||
softcap,
|
softcap,
|
||||||
return_debug_mask, /*return_softmax (this is used for testing)*/
|
return_debug_mask, /*return_softmax (this is used for testing)*/
|
||||||
std::nullopt);
|
std::nullopt);
|
||||||
|
|
|
||||||
|
|
@ -87,10 +87,8 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
|
||||||
auto contiguous_grad_out = grad_out.contiguous();
|
auto contiguous_grad_out = grad_out.contiguous();
|
||||||
auto contiguous_out = out.contiguous();
|
auto contiguous_out = out.contiguous();
|
||||||
|
|
||||||
#ifndef USE_ROCM // ROCM backend accepts std::optional for window_size_left/right directly.
|
|
||||||
const int non_null_window_left = window_size_left.has_value() ? window_size_left.value() : -1;
|
const int non_null_window_left = window_size_left.has_value() ? window_size_left.value() : -1;
|
||||||
const int non_null_window_right = window_size_right.has_value() ? window_size_right.value() : -1;
|
const int non_null_window_right = window_size_right.has_value() ? window_size_right.value() : -1;
|
||||||
#endif
|
|
||||||
|
|
||||||
std::optional<at::Tensor> dq{std::nullopt};
|
std::optional<at::Tensor> dq{std::nullopt};
|
||||||
std::optional<at::Tensor> dk{std::nullopt};
|
std::optional<at::Tensor> dk{std::nullopt};
|
||||||
|
|
@ -138,13 +136,8 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
false /*zero_tensors*/,
|
false /*zero_tensors*/,
|
||||||
is_causal,
|
is_causal,
|
||||||
#ifdef USE_ROCM
|
|
||||||
window_size_left,
|
|
||||||
window_size_right,
|
|
||||||
#else
|
|
||||||
non_null_window_left,
|
non_null_window_left,
|
||||||
non_null_window_right,
|
non_null_window_right,
|
||||||
#endif
|
|
||||||
softcap,
|
softcap,
|
||||||
determinisitic,
|
determinisitic,
|
||||||
philox_seed,
|
philox_seed,
|
||||||
|
|
@ -166,13 +159,8 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
|
||||||
dropout_p,
|
dropout_p,
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
is_causal,
|
is_causal,
|
||||||
#ifdef USE_ROCM
|
|
||||||
window_size_left,
|
|
||||||
window_size_right,
|
|
||||||
#else
|
|
||||||
non_null_window_left,
|
non_null_window_left,
|
||||||
non_null_window_right,
|
non_null_window_right,
|
||||||
#endif
|
|
||||||
softcap,
|
softcap,
|
||||||
determinisitic,
|
determinisitic,
|
||||||
philox_seed,
|
philox_seed,
|
||||||
|
|
|
||||||
|
|
@ -64,14 +64,8 @@
|
||||||
#include <aotriton/flash.h>
|
#include <aotriton/flash.h>
|
||||||
#include <aotriton/runtime.h>
|
#include <aotriton/runtime.h>
|
||||||
|
|
||||||
#if AOTRITON_VERSION_MINOR < 9
|
#if AOTRITON_VERSION_MINOR != 9
|
||||||
#error "This adaptor code is only tested with AOTriton >= 0.9"
|
#error "This adaptor code is only tested with AOTriton 0.9.x"
|
||||||
#endif
|
|
||||||
|
|
||||||
#if (AOTRITON_VERSION_MAJOR * 100 + AOTRITON_VERSION_MINOR) >= 10
|
|
||||||
#define V3_API 1
|
|
||||||
#else
|
|
||||||
#define V3_API 0
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace pytorch_flash {
|
namespace pytorch_flash {
|
||||||
|
|
@ -87,38 +81,6 @@ void check_gpu_arch(hipStream_t stream) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::tuple<bool, int, int>
|
|
||||||
calculate_swa(std::optional<int64_t> window_size_left,
|
|
||||||
std::optional<int64_t> window_size_right,
|
|
||||||
int max_seqlen_q,
|
|
||||||
int max_seqlen_k,
|
|
||||||
bool is_causal) {
|
|
||||||
#if V3_API // SWA is exposed through V3 API
|
|
||||||
bool needs_swa = false;
|
|
||||||
using aotriton::v3::flash::WindowValue;
|
|
||||||
// Default values when std::optional window_size_left/right have no value
|
|
||||||
int window_left = max_seqlen_q;
|
|
||||||
int window_right = max_seqlen_k;
|
|
||||||
if (is_causal) {
|
|
||||||
window_left = WindowValue::TopLeftAligned;
|
|
||||||
window_right = WindowValue::TopLeftAligned;
|
|
||||||
}
|
|
||||||
if (window_size_left.has_value() || window_size_right.has_value()) {
|
|
||||||
needs_swa = true;
|
|
||||||
window_left = window_size_left.value_or(window_left);
|
|
||||||
window_right = window_size_right.value_or(window_right);
|
|
||||||
}
|
|
||||||
return std::make_tuple(needs_swa, window_left, window_right);
|
|
||||||
#else
|
|
||||||
if (window_size_left.has_value() || window_size_right.has_value()) {
|
|
||||||
TORCH_WARN_ONCE("Current AOTriton does not support sliding window attention (SWA)."
|
|
||||||
" Both window_size_left and window_size_right will be ignored."
|
|
||||||
" Re-compile PyTorch with AOTriton >= 0.10b to enable SWA support.");
|
|
||||||
}
|
|
||||||
return std::make_tuple(false, 0, 0);
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
// We want to checkpoint and save the RNG state for backward if dropout
|
// We want to checkpoint and save the RNG state for backward if dropout
|
||||||
// We get the default generator and return the seed and offset which will
|
// We get the default generator and return the seed and offset which will
|
||||||
// be used in the backward function
|
// be used in the backward function
|
||||||
|
|
@ -165,8 +127,8 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x
|
||||||
const float p_dropout,
|
const float p_dropout,
|
||||||
const float softmax_scale,
|
const float softmax_scale,
|
||||||
bool is_causal,
|
bool is_causal,
|
||||||
std::optional<int64_t> window_size_left,
|
int window_size_left,
|
||||||
std::optional<int64_t> window_size_right,
|
int window_size_right,
|
||||||
const bool return_softmax,
|
const bool return_softmax,
|
||||||
const std::optional<at::Generator>& gen_) {
|
const std::optional<at::Generator>& gen_) {
|
||||||
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
|
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
|
||||||
|
|
@ -199,6 +161,7 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x
|
||||||
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
||||||
|
|
||||||
if (seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case
|
if (seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case
|
||||||
|
if (is_causal) { window_size_right = 0; }
|
||||||
|
|
||||||
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
|
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
|
||||||
CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
|
CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
|
||||||
|
|
@ -249,19 +212,6 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x
|
||||||
atomic_counter = at::zeros({1}, opts.dtype(at::kInt));
|
atomic_counter = at::zeros({1}, opts.dtype(at::kInt));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto [needs_swa, window_left, window_right] = calculate_swa(window_size_left,
|
|
||||||
window_size_right,
|
|
||||||
seqlen_q,
|
|
||||||
seqlen_k,
|
|
||||||
is_causal);
|
|
||||||
#if V3_API
|
|
||||||
const bool uses_swa = needs_swa;
|
|
||||||
#else
|
|
||||||
// When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be
|
|
||||||
// optimized out (hopefully).
|
|
||||||
constexpr bool uses_swa = false;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
hipError_t err; // TODO: Error handling
|
hipError_t err; // TODO: Error handling
|
||||||
using aotriton::v2::flash::attn_fwd;
|
using aotriton::v2::flash::attn_fwd;
|
||||||
using sdp::aotriton_adapter::mk_aotensor;
|
using sdp::aotriton_adapter::mk_aotensor;
|
||||||
|
|
@ -276,36 +226,6 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x
|
||||||
auto seed_output = mk_philoxtensor(use_philox_state ? seed_t.data_ptr<int64_t>() : nullptr);
|
auto seed_output = mk_philoxtensor(use_philox_state ? seed_t.data_ptr<int64_t>() : nullptr);
|
||||||
auto offset_output = mk_philoxtensor(use_philox_state ? offset_t.data_ptr<int64_t>() : nullptr);
|
auto offset_output = mk_philoxtensor(use_philox_state ? offset_t.data_ptr<int64_t>() : nullptr);
|
||||||
auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr<int32_t>() : nullptr);
|
auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr<int32_t>() : nullptr);
|
||||||
if (uses_swa) {
|
|
||||||
#if V3_API
|
|
||||||
using aotriton::v3::flash::CausalType;
|
|
||||||
using aotriton::v3::flash::VarlenType;
|
|
||||||
aotriton::v3::flash::attn_fwd_params params;
|
|
||||||
params.Q = mk_aotensor(q_t, "q");
|
|
||||||
params.K = mk_aotensor(k_t, "k");
|
|
||||||
params.V = mk_aotensor(v_t, "v");
|
|
||||||
params.Sm_scale = softmax_scale;
|
|
||||||
params.L = mk_aotensor<2>(M, "M");
|
|
||||||
params.Out = mk_aotensor(output_t, "Out");
|
|
||||||
params.Max_seqlen_q = seqlen_q; // Unused if cu_seqlens_q is empty
|
|
||||||
params.Max_seqlen_k = seqlen_k; // Unused if cu_seqlens_k is empty
|
|
||||||
params.dropout_p = p_dropout;
|
|
||||||
params.philox_seed_ptr = seed;
|
|
||||||
params.philox_offset1 = offset1;
|
|
||||||
params.philox_offset2 = offset2;
|
|
||||||
params.philox_seed_output = seed_output;
|
|
||||||
params.philox_offset_output = offset_output;
|
|
||||||
params.encoded_softmax = mk_aotensor(softmax_fa_t, "encoded_softmax");
|
|
||||||
params.persistent_atomic_counter = persistent_counter;
|
|
||||||
params.causal_type = CausalType::WindowedAttention;
|
|
||||||
params.varlen_type = VarlenType::None;
|
|
||||||
params.window_left = window_left;
|
|
||||||
params.window_right = window_right;
|
|
||||||
err = aotriton::v3::flash::attn_fwd(params,
|
|
||||||
aotriton::v3::flash::attn_fwd_params::kVersion,
|
|
||||||
stream);
|
|
||||||
#endif
|
|
||||||
} else {
|
|
||||||
err = attn_fwd(mk_aotensor(q_t, "q"),
|
err = attn_fwd(mk_aotensor(q_t, "q"),
|
||||||
mk_aotensor(k_t, "k"),
|
mk_aotensor(k_t, "k"),
|
||||||
mk_aotensor(v_t, "v"),
|
mk_aotensor(v_t, "v"),
|
||||||
|
|
@ -323,7 +243,6 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x
|
||||||
is_causal,
|
is_causal,
|
||||||
persistent_counter,
|
persistent_counter,
|
||||||
stream);
|
stream);
|
||||||
}
|
|
||||||
|
|
||||||
return {out, q_padded, k_padded, v_padded, M.view({batch_size, num_heads, seqlen_q}), seed_t, offset_t, softmax_fa_t};
|
return {out, q_padded, k_padded, v_padded, M.view({batch_size, num_heads, seqlen_q}), seed_t, offset_t, softmax_fa_t};
|
||||||
}
|
}
|
||||||
|
|
@ -344,8 +263,8 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
|
||||||
const float softmax_scale,
|
const float softmax_scale,
|
||||||
const bool zero_tensors,
|
const bool zero_tensors,
|
||||||
bool is_causal,
|
bool is_causal,
|
||||||
std::optional<int64_t> window_size_left,
|
int window_size_left,
|
||||||
std::optional<int64_t> window_size_right,
|
int window_size_right,
|
||||||
const bool return_softmax,
|
const bool return_softmax,
|
||||||
const std::optional<at::Generator>& gen_) {
|
const std::optional<at::Generator>& gen_) {
|
||||||
TORCH_CHECK(!seqused_k.has_value(), "[ROCm] mha_varlen_fwd: seqused_k must be nullopt");
|
TORCH_CHECK(!seqused_k.has_value(), "[ROCm] mha_varlen_fwd: seqused_k must be nullopt");
|
||||||
|
|
@ -393,6 +312,13 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
|
||||||
TORCH_CHECK(head_size_og <= 512, "FlashAttention on ROCm forward only supports head dimension at most 512");
|
TORCH_CHECK(head_size_og <= 512, "FlashAttention on ROCm forward only supports head dimension at most 512");
|
||||||
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
||||||
|
|
||||||
|
if (window_size_left >= max_seqlen_k) {
|
||||||
|
window_size_left = -1;
|
||||||
|
}
|
||||||
|
if (window_size_right >= max_seqlen_k) {
|
||||||
|
window_size_right = -1;
|
||||||
|
}
|
||||||
|
|
||||||
CHECK_SHAPE(temp_q, total_q, num_heads, head_size_og);
|
CHECK_SHAPE(temp_q, total_q, num_heads, head_size_og);
|
||||||
const int total_k = k.size(0);
|
const int total_k = k.size(0);
|
||||||
CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
|
CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
|
||||||
|
|
@ -442,19 +368,6 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto [needs_swa, window_left, window_right] = calculate_swa(window_size_left,
|
|
||||||
window_size_right,
|
|
||||||
max_seqlen_q,
|
|
||||||
max_seqlen_k,
|
|
||||||
is_causal);
|
|
||||||
#if V3_API
|
|
||||||
const bool uses_swa = needs_swa;
|
|
||||||
#else
|
|
||||||
// When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be
|
|
||||||
// optimized out (hopefully).
|
|
||||||
constexpr bool uses_swa = false;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto [seed_t, offset_t, philox_state, use_philox_state] =
|
auto [seed_t, offset_t, philox_state, use_philox_state] =
|
||||||
prepare_philox_arguments(p_dropout, batch_size * num_heads * 32);
|
prepare_philox_arguments(p_dropout, batch_size * num_heads * 32);
|
||||||
|
|
||||||
|
|
@ -477,36 +390,6 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
|
||||||
auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr<int64_t>()) : nullscalar;
|
auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr<int64_t>()) : nullscalar;
|
||||||
auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr<int64_t>()) : nullscalar;
|
auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr<int64_t>()) : nullscalar;
|
||||||
auto persistent_counter = is_causal ? mk_philoxtensor(atomic_counter.data_ptr<int64_t>()) : nullscalar;
|
auto persistent_counter = is_causal ? mk_philoxtensor(atomic_counter.data_ptr<int64_t>()) : nullscalar;
|
||||||
if (uses_swa) {
|
|
||||||
using aotriton::v3::flash::CausalType;
|
|
||||||
using aotriton::v3::flash::VarlenType;
|
|
||||||
aotriton::v3::flash::attn_fwd_params params;
|
|
||||||
params.Q = mk_aotensor(q_padded, "q");
|
|
||||||
params.K = mk_aotensor(k_padded, "k");
|
|
||||||
params.V = mk_aotensor(v_padded, "v");
|
|
||||||
params.Sm_scale = softmax_scale;
|
|
||||||
params.L = mk_aotensor<2>(M, "M");
|
|
||||||
params.Out = mk_aotensor(out_padded, "Out");
|
|
||||||
params.cu_seqlens_q = mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q");
|
|
||||||
params.cu_seqlens_k = mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k");
|
|
||||||
params.Max_seqlen_q = max_seqlen_q; // Unused if cu_seqlens_q is empty
|
|
||||||
params.Max_seqlen_k = max_seqlen_k; // Unused if cu_seqlens_k is empty
|
|
||||||
params.dropout_p = p_dropout;
|
|
||||||
params.philox_seed_ptr = seed;
|
|
||||||
params.philox_offset1 = offset1;
|
|
||||||
params.philox_offset2 = offset2;
|
|
||||||
params.philox_seed_output = seed_output;
|
|
||||||
params.philox_offset_output = offset_output;
|
|
||||||
params.encoded_softmax = mk_aotensor(softmax_fa_t, "encoded_softmax");
|
|
||||||
params.persistent_atomic_counter = persistent_counter;
|
|
||||||
params.causal_type = CausalType::WindowedAttention;
|
|
||||||
params.varlen_type = VarlenType::CompactVarlen;
|
|
||||||
params.window_left = window_left;
|
|
||||||
params.window_right = window_right;
|
|
||||||
err = aotriton::v3::flash::attn_fwd(params,
|
|
||||||
aotriton::v3::flash::attn_fwd_params::kVersion,
|
|
||||||
stream);
|
|
||||||
} else {
|
|
||||||
err = attn_fwd_compact_varlen(mk_aotensor(q_padded, "q"),
|
err = attn_fwd_compact_varlen(mk_aotensor(q_padded, "q"),
|
||||||
mk_aotensor(k_padded, "k"),
|
mk_aotensor(k_padded, "k"),
|
||||||
mk_aotensor(v_padded, "v"),
|
mk_aotensor(v_padded, "v"),
|
||||||
|
|
@ -528,7 +411,6 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
|
||||||
is_causal,
|
is_causal,
|
||||||
persistent_counter,
|
persistent_counter,
|
||||||
stream);
|
stream);
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
// If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
|
// If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
|
||||||
out.zero_();
|
out.zero_();
|
||||||
|
|
@ -552,8 +434,8 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea
|
||||||
const float p_dropout, // probability to drop
|
const float p_dropout, // probability to drop
|
||||||
const float softmax_scale,
|
const float softmax_scale,
|
||||||
const bool is_causal,
|
const bool is_causal,
|
||||||
std::optional<int64_t> window_size_left,
|
int window_size_left,
|
||||||
std::optional<int64_t> window_size_right,
|
int window_size_right,
|
||||||
const bool deterministic,
|
const bool deterministic,
|
||||||
const at::Tensor& philox_seed,
|
const at::Tensor& philox_seed,
|
||||||
const at::Tensor& philox_offset) {
|
const at::Tensor& philox_offset) {
|
||||||
|
|
@ -642,19 +524,6 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea
|
||||||
dv = at::empty_like(k);
|
dv = at::empty_like(k);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto [needs_swa, window_left, window_right] = calculate_swa(window_size_left,
|
|
||||||
window_size_right,
|
|
||||||
seqlen_q,
|
|
||||||
seqlen_k,
|
|
||||||
is_causal);
|
|
||||||
#if V3_API
|
|
||||||
const bool uses_swa = needs_swa;
|
|
||||||
#else
|
|
||||||
// When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be
|
|
||||||
// optimized out (hopefully).
|
|
||||||
constexpr bool uses_swa = false;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto opts = q.options();
|
auto opts = q.options();
|
||||||
auto softmax_d = at::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
|
auto softmax_d = at::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
|
||||||
|
|
||||||
|
|
@ -672,40 +541,10 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea
|
||||||
int d_head = head_size_og;
|
int d_head = head_size_og;
|
||||||
bool use_fused_bwd = d_head <= 192 && d_head * seqlen_q < 64 * 512;
|
bool use_fused_bwd = d_head <= 192 && d_head * seqlen_q < 64 * 512;
|
||||||
hipError_t err; // TODO: Error handling
|
hipError_t err; // TODO: Error handling
|
||||||
|
if (use_fused_bwd) {
|
||||||
|
using aotriton::v2::flash::attn_bwd_fused;
|
||||||
using sdp::aotriton_adapter::mk_aotensor;
|
using sdp::aotriton_adapter::mk_aotensor;
|
||||||
using sdp::aotriton_adapter::mk_aoscalartensor;
|
using sdp::aotriton_adapter::mk_aoscalartensor;
|
||||||
if (uses_swa) {
|
|
||||||
// Fused BWD does not support SWA
|
|
||||||
at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous();
|
|
||||||
using aotriton::v3::flash::CausalType;
|
|
||||||
using aotriton::v3::flash::VarlenType;
|
|
||||||
aotriton::v3::flash::attn_bwd_params params;
|
|
||||||
params.Q = mk_aotensor(q_t, "q");
|
|
||||||
params.K = mk_aotensor(k_t, "k");
|
|
||||||
params.V = mk_aotensor(v_t, "v");
|
|
||||||
params.Sm_scale = softmax_scale;
|
|
||||||
params.Out = mk_aotensor(out_t, "out");
|
|
||||||
params.DO = mk_aotensor(dout_t, "dout");
|
|
||||||
params.DK = mk_aotensor(dq_t, "dq");
|
|
||||||
params.DV = mk_aotensor(dk_t, "dk");
|
|
||||||
params.DQ = mk_aotensor(dv_t, "dv");
|
|
||||||
params.L = mk_aotensor<2>(softmax_lse_cont, "L");
|
|
||||||
params.D = mk_aotensor<2>(delta, "delta");
|
|
||||||
params.Max_seqlen_q = seqlen_q; // Unused if cu_seqlens_q is empty
|
|
||||||
params.Max_seqlen_k = seqlen_k; // Unused if cu_seqlens_k is empty
|
|
||||||
params.dropout_p = p_dropout;
|
|
||||||
params.philox_seed_ptr = mk_aoscalartensor(philox_seed);
|
|
||||||
params.philox_offset1 = mk_aoscalartensor(philox_offset);
|
|
||||||
params.philox_offset2 = 0;
|
|
||||||
params.causal_type = CausalType::WindowedAttention;
|
|
||||||
params.varlen_type = VarlenType::None;
|
|
||||||
params.window_left = window_left;
|
|
||||||
params.window_right = window_right;
|
|
||||||
err = aotriton::v3::flash::attn_bwd(params,
|
|
||||||
aotriton::v3::flash::attn_bwd_params::kVersion,
|
|
||||||
stream);
|
|
||||||
} else if (use_fused_bwd) {
|
|
||||||
using aotriton::v2::flash::attn_bwd_fused;
|
|
||||||
using sdp::aotriton_adapter::cast_dtype;
|
using sdp::aotriton_adapter::cast_dtype;
|
||||||
aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
|
aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
|
||||||
err = attn_bwd_fused(mk_aotensor(q_t, "q"),
|
err = attn_bwd_fused(mk_aotensor(q_t, "q"),
|
||||||
|
|
@ -729,6 +568,8 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea
|
||||||
} else {
|
} else {
|
||||||
at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous();
|
at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous();
|
||||||
using aotriton::v2::flash::attn_bwd;
|
using aotriton::v2::flash::attn_bwd;
|
||||||
|
using sdp::aotriton_adapter::mk_aotensor;
|
||||||
|
using sdp::aotriton_adapter::mk_aoscalartensor;
|
||||||
using sdp::aotriton_adapter::cast_dtype;
|
using sdp::aotriton_adapter::cast_dtype;
|
||||||
aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
|
aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
|
||||||
err = attn_bwd(mk_aotensor(q_t, "q"),
|
err = attn_bwd(mk_aotensor(q_t, "q"),
|
||||||
|
|
@ -774,14 +615,17 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||||
const float softmax_scale,
|
const float softmax_scale,
|
||||||
const bool zero_tensors,
|
const bool zero_tensors,
|
||||||
const bool is_causal,
|
const bool is_causal,
|
||||||
std::optional<int64_t> window_size_left,
|
int window_size_left,
|
||||||
std::optional<int64_t> window_size_right,
|
int window_size_right,
|
||||||
const bool deterministic,
|
const bool deterministic,
|
||||||
const at::Tensor& philox_seed,
|
const at::Tensor& philox_seed,
|
||||||
const at::Tensor& philox_offset)
|
const at::Tensor& philox_offset)
|
||||||
{
|
{
|
||||||
TORCH_CHECK(!alibi_slopes_.has_value(), "[ROCm] mha_varlen_fwd: alibi_slopes_ must be nullopt");
|
TORCH_CHECK(!alibi_slopes_.has_value(), "[ROCm] mha_varlen_fwd: alibi_slopes_ must be nullopt");
|
||||||
|
|
||||||
|
if (is_causal) {
|
||||||
|
window_size_right = 0;
|
||||||
|
}
|
||||||
// Otherwise the kernel will be launched from cuda:0 device
|
// Otherwise the kernel will be launched from cuda:0 device
|
||||||
// Cast to char to avoid compiler warning about narrowing
|
// Cast to char to avoid compiler warning about narrowing
|
||||||
at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
|
at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
|
||||||
|
|
@ -825,6 +669,9 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||||
TORCH_CHECK(head_size <= 512, "FlashAttention on ROCm backward only supports head dimension at most 512");
|
TORCH_CHECK(head_size <= 512, "FlashAttention on ROCm backward only supports head dimension at most 512");
|
||||||
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
||||||
|
|
||||||
|
if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
|
||||||
|
if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
|
||||||
|
|
||||||
CHECK_SHAPE(q, total_q, num_heads, head_size);
|
CHECK_SHAPE(q, total_q, num_heads, head_size);
|
||||||
CHECK_SHAPE(k, total_k, num_heads_k, head_size);
|
CHECK_SHAPE(k, total_k, num_heads_k, head_size);
|
||||||
CHECK_SHAPE(v, total_k, num_heads_k, head_size);
|
CHECK_SHAPE(v, total_k, num_heads_k, head_size);
|
||||||
|
|
@ -887,19 +734,6 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||||
softmax_d.zero_();
|
softmax_d.zero_();
|
||||||
}
|
}
|
||||||
|
|
||||||
auto [needs_swa, window_left, window_right] = calculate_swa(window_size_left,
|
|
||||||
window_size_right,
|
|
||||||
max_seqlen_q,
|
|
||||||
max_seqlen_k,
|
|
||||||
is_causal);
|
|
||||||
#if V3_API
|
|
||||||
const bool uses_swa = needs_swa;
|
|
||||||
#else
|
|
||||||
// When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be
|
|
||||||
// optimized out (hopefully).
|
|
||||||
constexpr bool uses_swa = false;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
at::PhiloxCudaState philox_args;
|
at::PhiloxCudaState philox_args;
|
||||||
if (is_dropout) {
|
if (is_dropout) {
|
||||||
if (at::cuda::currentStreamCaptureStatus() ==
|
if (at::cuda::currentStreamCaptureStatus() ==
|
||||||
|
|
@ -913,40 +747,9 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||||
}
|
}
|
||||||
if (max_seqlen_q > 0) {
|
if (max_seqlen_q > 0) {
|
||||||
hipError_t err; // TODO: Error handling
|
hipError_t err; // TODO: Error handling
|
||||||
|
using aotriton::v2::flash::attn_bwd_compact_varlen;
|
||||||
using sdp::aotriton_adapter::mk_aotensor;
|
using sdp::aotriton_adapter::mk_aotensor;
|
||||||
using sdp::aotriton_adapter::mk_aoscalartensor;
|
using sdp::aotriton_adapter::mk_aoscalartensor;
|
||||||
if (uses_swa) {
|
|
||||||
using aotriton::v3::flash::CausalType;
|
|
||||||
using aotriton::v3::flash::VarlenType;
|
|
||||||
aotriton::v3::flash::attn_bwd_params params;
|
|
||||||
params.Q = mk_aotensor(q_padded, "q");
|
|
||||||
params.K = mk_aotensor(k_padded, "k");
|
|
||||||
params.V = mk_aotensor(v_padded, "v");
|
|
||||||
params.Sm_scale = softmax_scale;
|
|
||||||
params.Out = mk_aotensor(out_t, "out");
|
|
||||||
params.DO = mk_aotensor(dout_t, "dout");
|
|
||||||
params.DK = mk_aotensor(dq_padded, "dq");
|
|
||||||
params.DV = mk_aotensor(dk_padded, "dk");
|
|
||||||
params.DQ = mk_aotensor(dv_padded, "dv");
|
|
||||||
params.L = mk_aotensor<2>(softmax_lse_cont, "L");
|
|
||||||
params.D = mk_aotensor<2>(delta, "delta");
|
|
||||||
params.cu_seqlens_q = mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q");
|
|
||||||
params.cu_seqlens_k = mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k");
|
|
||||||
params.Max_seqlen_q = max_seqlen_q; // Unused if cu_seqlens_q is empty
|
|
||||||
params.Max_seqlen_k = max_seqlen_k; // Unused if cu_seqlens_k is empty
|
|
||||||
params.dropout_p = p_dropout;
|
|
||||||
params.philox_seed_ptr = mk_aoscalartensor(philox_seed);
|
|
||||||
params.philox_offset1 = mk_aoscalartensor(philox_offset);
|
|
||||||
params.philox_offset2 = 0;
|
|
||||||
params.causal_type = CausalType::WindowedAttention;
|
|
||||||
params.varlen_type = VarlenType::CompactVarlen;
|
|
||||||
params.window_left = window_left;
|
|
||||||
params.window_right = window_right;
|
|
||||||
err = aotriton::v3::flash::attn_bwd(params,
|
|
||||||
aotriton::v3::flash::attn_bwd_params::kVersion,
|
|
||||||
stream);
|
|
||||||
} else {
|
|
||||||
using aotriton::v2::flash::attn_bwd_compact_varlen;
|
|
||||||
using sdp::aotriton_adapter::cast_dtype;
|
using sdp::aotriton_adapter::cast_dtype;
|
||||||
aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
|
aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
|
||||||
err = attn_bwd_compact_varlen(mk_aotensor(q_padded, "q"),
|
err = attn_bwd_compact_varlen(mk_aotensor(q_padded, "q"),
|
||||||
|
|
@ -972,7 +775,6 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||||
0,
|
0,
|
||||||
is_causal,
|
is_causal,
|
||||||
stream);
|
stream);
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
// If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
|
// If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
|
||||||
dq.zero_();
|
dq.zero_();
|
||||||
|
|
|
||||||
|
|
@ -51,8 +51,8 @@ mha_fwd_aot(
|
||||||
const float p_dropout,
|
const float p_dropout,
|
||||||
const float softmax_scale,
|
const float softmax_scale,
|
||||||
bool is_causal,
|
bool is_causal,
|
||||||
std::optional<int64_t> window_size_left,
|
int window_size_left,
|
||||||
std::optional<int64_t> window_size_right,
|
int window_size_right,
|
||||||
const bool return_softmax,
|
const bool return_softmax,
|
||||||
const std::optional<at::Generator>& gen_);
|
const std::optional<at::Generator>& gen_);
|
||||||
|
|
||||||
|
|
@ -87,8 +87,8 @@ mha_varlen_fwd_aot(
|
||||||
const float softmax_scale,
|
const float softmax_scale,
|
||||||
const bool zero_tensors,
|
const bool zero_tensors,
|
||||||
bool is_causal,
|
bool is_causal,
|
||||||
std::optional<int64_t> window_size_left,
|
int window_size_left,
|
||||||
std::optional<int64_t> window_size_right,
|
int window_size_right,
|
||||||
const bool return_softmax,
|
const bool return_softmax,
|
||||||
const std::optional<at::Generator>& gen_);
|
const std::optional<at::Generator>& gen_);
|
||||||
|
|
||||||
|
|
@ -110,8 +110,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd_aot(
|
||||||
const float p_dropout, // probability to drop
|
const float p_dropout, // probability to drop
|
||||||
const float softmax_scale,
|
const float softmax_scale,
|
||||||
const bool is_causal,
|
const bool is_causal,
|
||||||
std::optional<int64_t> window_size_left,
|
int window_size_left,
|
||||||
std::optional<int64_t> window_size_right,
|
int window_size_right,
|
||||||
const bool deterministic,
|
const bool deterministic,
|
||||||
const at::Tensor& philox_seed,
|
const at::Tensor& philox_seed,
|
||||||
const at::Tensor& philox_offset);
|
const at::Tensor& philox_offset);
|
||||||
|
|
@ -141,8 +141,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd_aot(
|
||||||
const float softmax_scale,
|
const float softmax_scale,
|
||||||
const bool zero_tensors,
|
const bool zero_tensors,
|
||||||
const bool is_causal,
|
const bool is_causal,
|
||||||
std::optional<int64_t> window_size_left,
|
int window_size_left,
|
||||||
std::optional<int64_t> window_size_right,
|
int window_size_right,
|
||||||
const bool deterministic,
|
const bool deterministic,
|
||||||
const at::Tensor& philox_seed,
|
const at::Tensor& philox_seed,
|
||||||
const at::Tensor& philox_offset);
|
const at::Tensor& philox_offset);
|
||||||
|
|
@ -290,16 +290,14 @@ mha_fwd(
|
||||||
const float p_dropout,
|
const float p_dropout,
|
||||||
const float softmax_scale,
|
const float softmax_scale,
|
||||||
bool is_causal,
|
bool is_causal,
|
||||||
std::optional<int64_t> window_size_left,
|
int window_size_left,
|
||||||
std::optional<int64_t> window_size_right,
|
int window_size_right,
|
||||||
const float softcap,
|
const float softcap,
|
||||||
const bool return_softmax,
|
const bool return_softmax,
|
||||||
std::optional<at::Generator> gen_) {
|
std::optional<at::Generator> gen_) {
|
||||||
#if defined(USE_CK_FLASH_ATTENTION)
|
#if defined(USE_CK_FLASH_ATTENTION)
|
||||||
if (at::globalContext().getROCmFAPreferredBackend() ==
|
if (at::globalContext().getROCmFAPreferredBackend() ==
|
||||||
at::ROCmFABackend::Ck) {
|
at::ROCmFABackend::Ck) {
|
||||||
const int non_null_window_left = window_size_left.value_or(-1);
|
|
||||||
const int non_null_window_right = window_size_right.value_or(-1);
|
|
||||||
std::optional<at::Tensor> dummy_attn_bias = std::nullopt;
|
std::optional<at::Tensor> dummy_attn_bias = std::nullopt;
|
||||||
return mha_fwd_ck(
|
return mha_fwd_ck(
|
||||||
q,
|
q,
|
||||||
|
|
@ -309,13 +307,12 @@ mha_fwd(
|
||||||
p_dropout,
|
p_dropout,
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
is_causal,
|
is_causal,
|
||||||
non_null_window_left,
|
window_size_left,
|
||||||
non_null_window_right,
|
window_size_right,
|
||||||
return_softmax,
|
return_softmax,
|
||||||
gen_,
|
gen_,
|
||||||
dummy_attn_bias); // Not used in flash attention
|
dummy_attn_bias); // Not used in flash attention
|
||||||
}
|
} else {
|
||||||
#endif
|
|
||||||
return mha_fwd_aot(
|
return mha_fwd_aot(
|
||||||
q,
|
q,
|
||||||
k,
|
k,
|
||||||
|
|
@ -330,6 +327,22 @@ mha_fwd(
|
||||||
return_softmax,
|
return_softmax,
|
||||||
gen_);
|
gen_);
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
return mha_fwd_aot(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out_,
|
||||||
|
alibi_slopes_,
|
||||||
|
p_dropout,
|
||||||
|
softmax_scale,
|
||||||
|
is_causal,
|
||||||
|
window_size_left,
|
||||||
|
window_size_right,
|
||||||
|
return_softmax,
|
||||||
|
gen_);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
inline std::tuple<
|
inline std::tuple<
|
||||||
at::Tensor,
|
at::Tensor,
|
||||||
|
|
@ -363,8 +376,8 @@ mha_varlen_fwd(
|
||||||
const float softmax_scale,
|
const float softmax_scale,
|
||||||
const bool zero_tensors,
|
const bool zero_tensors,
|
||||||
bool is_causal,
|
bool is_causal,
|
||||||
std::optional<int64_t> window_size_left,
|
int window_size_left,
|
||||||
std::optional<int64_t> window_size_right,
|
int window_size_right,
|
||||||
const float softcap,
|
const float softcap,
|
||||||
const bool return_softmax,
|
const bool return_softmax,
|
||||||
std::optional<at::Generator> gen_) {
|
std::optional<at::Generator> gen_) {
|
||||||
|
|
@ -372,8 +385,6 @@ mha_varlen_fwd(
|
||||||
if (at::globalContext().getROCmFAPreferredBackend() ==
|
if (at::globalContext().getROCmFAPreferredBackend() ==
|
||||||
at::ROCmFABackend::Ck) {
|
at::ROCmFABackend::Ck) {
|
||||||
std::optional<at::Tensor> dummy_attn_bias = std::nullopt;
|
std::optional<at::Tensor> dummy_attn_bias = std::nullopt;
|
||||||
const int non_null_window_left = window_size_left.value_or(-1);
|
|
||||||
const int non_null_window_right = window_size_right.value_or(-1);
|
|
||||||
return mha_varlen_fwd_ck(
|
return mha_varlen_fwd_ck(
|
||||||
q,
|
q,
|
||||||
k,
|
k,
|
||||||
|
|
@ -388,13 +399,12 @@ mha_varlen_fwd(
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
zero_tensors,
|
zero_tensors,
|
||||||
is_causal,
|
is_causal,
|
||||||
non_null_window_left,
|
window_size_left,
|
||||||
non_null_window_right,
|
window_size_right,
|
||||||
return_softmax,
|
return_softmax,
|
||||||
gen_,
|
gen_,
|
||||||
dummy_attn_bias); // Not used in flash attention
|
dummy_attn_bias); // Not used in flash attention
|
||||||
}
|
} else {
|
||||||
#endif
|
|
||||||
return mha_varlen_fwd_aot(
|
return mha_varlen_fwd_aot(
|
||||||
q,
|
q,
|
||||||
k,
|
k,
|
||||||
|
|
@ -416,6 +426,29 @@ mha_varlen_fwd(
|
||||||
return_softmax,
|
return_softmax,
|
||||||
gen_);
|
gen_);
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
return mha_varlen_fwd_aot(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out_,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
seqused_k,
|
||||||
|
block_table_,
|
||||||
|
alibi_slopes_,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
p_dropout,
|
||||||
|
softmax_scale,
|
||||||
|
zero_tensors,
|
||||||
|
is_causal,
|
||||||
|
window_size_left,
|
||||||
|
window_size_right,
|
||||||
|
return_softmax,
|
||||||
|
gen_);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd(
|
inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd(
|
||||||
const at::Tensor& dout, // batch_size x seqlen_q x num_heads, x head_size_og
|
const at::Tensor& dout, // batch_size x seqlen_q x num_heads, x head_size_og
|
||||||
|
|
@ -435,18 +468,16 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd(
|
||||||
const float p_dropout, // probability to drop
|
const float p_dropout, // probability to drop
|
||||||
const float softmax_scale,
|
const float softmax_scale,
|
||||||
const bool is_causal,
|
const bool is_causal,
|
||||||
std::optional<int64_t> window_size_left,
|
int window_size_left,
|
||||||
std::optional<int64_t> window_size_right,
|
int window_size_right,
|
||||||
const float softcap,
|
const float softcap,
|
||||||
const bool deterministic,
|
const bool deterministic,
|
||||||
const at::Tensor philox_seed,
|
const at::Tensor philox_seed,
|
||||||
const at::Tensor philox_offset) {
|
const at::Tensor philox_offset) {
|
||||||
|
#if defined(USE_CK_FLASH_ATTENTION)
|
||||||
if (at::globalContext().getROCmFAPreferredBackend() ==
|
if (at::globalContext().getROCmFAPreferredBackend() ==
|
||||||
at::ROCmFABackend::Ck) {
|
at::ROCmFABackend::Ck) {
|
||||||
#if defined(USE_CK_FLASH_ATTENTION)
|
|
||||||
std::optional<at::Tensor> non_null_dbias = std::nullopt;
|
std::optional<at::Tensor> non_null_dbias = std::nullopt;
|
||||||
const int non_null_window_left = window_size_left.value_or(-1);
|
|
||||||
const int non_null_window_right = window_size_right.value_or(-1);
|
|
||||||
auto[dQuery,
|
auto[dQuery,
|
||||||
dKey,
|
dKey,
|
||||||
dValue,
|
dValue,
|
||||||
|
|
@ -467,16 +498,38 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd(
|
||||||
p_dropout,
|
p_dropout,
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
is_causal,
|
is_causal,
|
||||||
non_null_window_left,
|
window_size_left,
|
||||||
non_null_window_right,
|
window_size_right,
|
||||||
deterministic,
|
deterministic,
|
||||||
philox_seed,
|
philox_seed,
|
||||||
philox_offset);
|
philox_offset);
|
||||||
// for FA return [dQ, dV, dK, dSoftmax]
|
// for FA return [dQ, dV, dK, dSoftmax]
|
||||||
return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue), std::move(dSoftmax));
|
return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue), std::move(dSoftmax));
|
||||||
|
} else {
|
||||||
|
return mha_bwd_aot(
|
||||||
|
dout,
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out,
|
||||||
|
softmax_lse,
|
||||||
|
dq_,
|
||||||
|
dk_,
|
||||||
|
dv_,
|
||||||
|
alibi_slopes_,
|
||||||
|
p_dropout,
|
||||||
|
softmax_scale,
|
||||||
|
is_causal,
|
||||||
|
window_size_left,
|
||||||
|
window_size_right,
|
||||||
|
deterministic,
|
||||||
|
philox_seed,
|
||||||
|
philox_offset);
|
||||||
|
}
|
||||||
#else
|
#else
|
||||||
|
if(at::globalContext().getROCmFAPreferredBackend() ==
|
||||||
|
at::ROCmFABackend::Ck) {
|
||||||
TORCH_WARN_ONCE("Warning! You have opted to use CK flash attention backend in a build that was not compiled using USE_CK_FLASH_ATTENTION=1. Please set this variable and try again. Defaulting to use aotriton backend...");
|
TORCH_WARN_ONCE("Warning! You have opted to use CK flash attention backend in a build that was not compiled using USE_CK_FLASH_ATTENTION=1. Please set this variable and try again. Defaulting to use aotriton backend...");
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
return mha_bwd_aot(
|
return mha_bwd_aot(
|
||||||
dout,
|
dout,
|
||||||
|
|
@ -497,6 +550,7 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd(
|
||||||
deterministic,
|
deterministic,
|
||||||
philox_seed,
|
philox_seed,
|
||||||
philox_offset);
|
philox_offset);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd(
|
inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd(
|
||||||
|
|
@ -524,8 +578,8 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd
|
||||||
const float softmax_scale,
|
const float softmax_scale,
|
||||||
const bool zero_tensors,
|
const bool zero_tensors,
|
||||||
const bool is_causal,
|
const bool is_causal,
|
||||||
std::optional<int64_t> window_size_left,
|
int window_size_left,
|
||||||
std::optional<int64_t> window_size_right,
|
int window_size_right,
|
||||||
const float softcap,
|
const float softcap,
|
||||||
const bool deterministic,
|
const bool deterministic,
|
||||||
const at::Tensor philox_seed,
|
const at::Tensor philox_seed,
|
||||||
|
|
@ -534,8 +588,6 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd
|
||||||
if (at::globalContext().getROCmFAPreferredBackend() ==
|
if (at::globalContext().getROCmFAPreferredBackend() ==
|
||||||
at::ROCmFABackend::Ck) {
|
at::ROCmFABackend::Ck) {
|
||||||
std::optional<at::Tensor> non_null_dbias = std::nullopt;
|
std::optional<at::Tensor> non_null_dbias = std::nullopt;
|
||||||
const int non_null_window_left = window_size_left.value_or(-1);
|
|
||||||
const int non_null_window_right = window_size_right.value_or(-1);
|
|
||||||
auto[dQuery,
|
auto[dQuery,
|
||||||
dKey,
|
dKey,
|
||||||
dValue,
|
dValue,
|
||||||
|
|
@ -561,15 +613,14 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
zero_tensors,
|
zero_tensors,
|
||||||
is_causal,
|
is_causal,
|
||||||
non_null_window_left,
|
window_size_left,
|
||||||
non_null_window_right,
|
window_size_right,
|
||||||
deterministic,
|
deterministic,
|
||||||
philox_seed,
|
philox_seed,
|
||||||
philox_offset);
|
philox_offset);
|
||||||
// for FA return [dQ, dV, dK, dSoftmax]
|
// for FA return [dQ, dV, dK, dSoftmax]
|
||||||
return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue), std::move(dSoftmax));
|
return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue), std::move(dSoftmax));
|
||||||
}
|
} else {
|
||||||
#endif
|
|
||||||
return mha_varlen_bwd_aot(
|
return mha_varlen_bwd_aot(
|
||||||
dout,
|
dout,
|
||||||
q,
|
q,
|
||||||
|
|
@ -595,5 +646,32 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd
|
||||||
philox_seed,
|
philox_seed,
|
||||||
philox_offset);
|
philox_offset);
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
return mha_varlen_bwd_aot(
|
||||||
|
dout,
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out,
|
||||||
|
softmax_lse,
|
||||||
|
dq_,
|
||||||
|
dk_,
|
||||||
|
dv_,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
alibi_slopes_,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
p_dropout,
|
||||||
|
softmax_scale,
|
||||||
|
zero_tensors,
|
||||||
|
is_causal,
|
||||||
|
window_size_left,
|
||||||
|
window_size_right,
|
||||||
|
deterministic,
|
||||||
|
philox_seed,
|
||||||
|
philox_offset);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace pytorch_flash
|
} // namespace pytorch_flash
|
||||||
|
|
|
||||||
33
cmake/External/aotriton.cmake
vendored
33
cmake/External/aotriton.cmake
vendored
|
|
@ -1,3 +1,16 @@
|
||||||
|
macro(get_target_gpus_from_pytorch target_gpus)
|
||||||
|
set(gfx90a_key MI200)
|
||||||
|
set(gfx942_key MI300X)
|
||||||
|
set(gfx1100_key Navi31)
|
||||||
|
|
||||||
|
foreach(X IN LISTS PYTORCH_ROCM_ARCH)
|
||||||
|
set(key ${X})
|
||||||
|
string(APPEND key "_key")
|
||||||
|
string(APPEND target_gpus ${${key}})
|
||||||
|
string(APPEND target_gpus "|")
|
||||||
|
endforeach()
|
||||||
|
endmacro()
|
||||||
|
|
||||||
if(NOT __AOTRITON_INCLUDED)
|
if(NOT __AOTRITON_INCLUDED)
|
||||||
set(__AOTRITON_INCLUDED TRUE)
|
set(__AOTRITON_INCLUDED TRUE)
|
||||||
|
|
||||||
|
|
@ -9,22 +22,22 @@ if(NOT __AOTRITON_INCLUDED)
|
||||||
# Replaces .ci/docker/aotriton_version.txt
|
# Replaces .ci/docker/aotriton_version.txt
|
||||||
# Note packages information may have versions skipped (due to no ABI breaks)
|
# Note packages information may have versions skipped (due to no ABI breaks)
|
||||||
# But they must be listed from lower version to higher version
|
# But they must be listed from lower version to higher version
|
||||||
set(__AOTRITON_VER "0.10b")
|
set(__AOTRITON_VER "0.9.2b")
|
||||||
set(__AOTRITON_MANYLINUX_LIST
|
set(__AOTRITON_MANYLINUX_LIST
|
||||||
|
"manylinux_2_28" # rocm6.2
|
||||||
"manylinux_2_28" # rocm6.3
|
"manylinux_2_28" # rocm6.3
|
||||||
"manylinux_2_28" # rocm6.4
|
"manylinux_2_28" # rocm6.4
|
||||||
"manylinux_2_28" # rocm7.0
|
|
||||||
)
|
)
|
||||||
set(__AOTRITON_ROCM_LIST
|
set(__AOTRITON_ROCM_LIST
|
||||||
|
"rocm6.2"
|
||||||
"rocm6.3"
|
"rocm6.3"
|
||||||
"rocm6.4"
|
"rocm6.4"
|
||||||
"rocm7.0"
|
|
||||||
)
|
)
|
||||||
set(__AOTRITON_CI_COMMIT "6fca155f4deeb8d9529326f7b69f350aeeb93477")
|
set(__AOTRITON_CI_COMMIT "b388d223d8c7213545603e00f6f3148c54d1f525")
|
||||||
set(__AOTRITON_SHA256_LIST
|
set(__AOTRITON_SHA256_LIST
|
||||||
"861cd9f7479eec943933c27cb86920247e5b5dd139bc7c1376c81808abb7d7fe" # rocm6.3
|
"08d84f96f4c984179f80f517c0431c7511ee26bb0ce9bd05a827573ddd78cc79" # rocm6.2
|
||||||
"acea7d811a2d3bbe718b6e07fc2a9f739e49eecd60b4b6a36fcb3fe8edf85d78" # rocm6.4
|
"9094d59717e7e6eace9126ca100dd0e86510f07fc6c3a349569fc4e2d9056604" # rocm6.3
|
||||||
"7e29c325d5bd33ba896ddb106f5d4fc7d715274dca7fe937f724fffa82017838" # rocm7.0
|
"41190202c2736d5ff75b13a3abc0fb52ebfbb67226cf85dc3de7699c7000db44" # rocm6.4
|
||||||
)
|
)
|
||||||
set(__AOTRITON_Z "gz")
|
set(__AOTRITON_Z "gz")
|
||||||
|
|
||||||
|
|
@ -37,13 +50,17 @@ if(NOT __AOTRITON_INCLUDED)
|
||||||
set(__AOTRITON_INSTALL_DIR "$ENV{AOTRITON_INSTALLED_PREFIX}")
|
set(__AOTRITON_INSTALL_DIR "$ENV{AOTRITON_INSTALLED_PREFIX}")
|
||||||
message(STATUS "Using Preinstalled AOTriton at ${__AOTRITON_INSTALL_DIR}")
|
message(STATUS "Using Preinstalled AOTriton at ${__AOTRITON_INSTALL_DIR}")
|
||||||
elseif(DEFINED ENV{AOTRITON_INSTALL_FROM_SOURCE})
|
elseif(DEFINED ENV{AOTRITON_INSTALL_FROM_SOURCE})
|
||||||
|
set(target_gpus "")
|
||||||
|
get_target_gpus_from_pytorch(target_gpus)
|
||||||
ExternalProject_Add(aotriton_external
|
ExternalProject_Add(aotriton_external
|
||||||
GIT_REPOSITORY https://github.com/ROCm/aotriton.git
|
GIT_REPOSITORY https://github.com/ROCm/aotriton.git
|
||||||
GIT_TAG ${__AOTRITON_CI_COMMIT}
|
GIT_TAG ${__AOTRITON_CI_COMMIT}
|
||||||
PREFIX ${__AOTRITON_EXTERN_PREFIX}
|
PREFIX ${__AOTRITON_EXTERN_PREFIX}
|
||||||
INSTALL_DIR ${__AOTRITON_INSTALL_DIR}
|
INSTALL_DIR ${__AOTRITON_INSTALL_DIR}
|
||||||
|
LIST_SEPARATOR |
|
||||||
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__AOTRITON_INSTALL_DIR}
|
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__AOTRITON_INSTALL_DIR}
|
||||||
-DAOTRITON_TARGET_ARCH:STRING=${PYTORCH_ROCM_ARCH}
|
-DTARGET_GPUS:STRING=${target_gpus}
|
||||||
|
-DAOTRITON_COMPRESS_KERNEL=ON
|
||||||
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
|
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
|
||||||
-DAOTRITON_NO_PYTHON=ON
|
-DAOTRITON_NO_PYTHON=ON
|
||||||
-DAOTRITON_NO_SHARED=OFF
|
-DAOTRITON_NO_SHARED=OFF
|
||||||
|
|
|
||||||
|
|
@ -3294,7 +3294,7 @@ class TestSDPACudaOnly(NNTestCase):
|
||||||
fudge_factors['grad_key'] = 70.0
|
fudge_factors['grad_key'] = 70.0
|
||||||
if seq_len_k >= 2048:
|
if seq_len_k >= 2048:
|
||||||
fudge_factors['grad_key'] = 160.0
|
fudge_factors['grad_key'] = 160.0
|
||||||
fudge_factors['grad_query'] = 670.0
|
fudge_factors['grad_query'] = 650.0
|
||||||
if dtype == torch.float32:
|
if dtype == torch.float32:
|
||||||
fudge_factors['grad_key'] = 90.0
|
fudge_factors['grad_key'] = 90.0
|
||||||
|
|
||||||
|
|
@ -3415,7 +3415,7 @@ class TestSDPACudaOnly(NNTestCase):
|
||||||
fudge_factors['grad_key'] = 70.0
|
fudge_factors['grad_key'] = 70.0
|
||||||
if seq_len_k >= 2048:
|
if seq_len_k >= 2048:
|
||||||
fudge_factors['grad_key'] = 160.0
|
fudge_factors['grad_key'] = 160.0
|
||||||
fudge_factors['grad_query'] = 670.0 # gfx90a
|
fudge_factors['grad_query'] = 650.0
|
||||||
if dtype == torch.float32:
|
if dtype == torch.float32:
|
||||||
fudge_factors['grad_key'] = 90.0
|
fudge_factors['grad_key'] = 90.0
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -57,9 +57,9 @@ def CDNA2OrLater():
|
||||||
|
|
||||||
def evaluate_platform_supports_flash_attention():
|
def evaluate_platform_supports_flash_attention():
|
||||||
if TEST_WITH_ROCM:
|
if TEST_WITH_ROCM:
|
||||||
arch_list = ["gfx90a", "gfx942", "gfx1100", "gfx1201", "gfx950"]
|
arch_list = ["gfx90a", "gfx942", "gfx1100"]
|
||||||
if os.environ.get("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "0") != "0":
|
if os.environ.get("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "0") != "0":
|
||||||
arch_list += ["gfx1101", "gfx1150", "gfx1151", "gfx1200"]
|
arch_list += ["gfx1201", "gfx950"]
|
||||||
return evaluate_gfx_arch_within(arch_list)
|
return evaluate_gfx_arch_within(arch_list)
|
||||||
if TEST_CUDA:
|
if TEST_CUDA:
|
||||||
return not IS_WINDOWS and SM80OrLater
|
return not IS_WINDOWS and SM80OrLater
|
||||||
|
|
@ -67,9 +67,9 @@ def evaluate_platform_supports_flash_attention():
|
||||||
|
|
||||||
def evaluate_platform_supports_efficient_attention():
|
def evaluate_platform_supports_efficient_attention():
|
||||||
if TEST_WITH_ROCM:
|
if TEST_WITH_ROCM:
|
||||||
arch_list = ["gfx90a", "gfx942", "gfx1100", "gfx1201", "gfx950"]
|
arch_list = ["gfx90a", "gfx942", "gfx1100"]
|
||||||
if os.environ.get("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "0") != "0":
|
if os.environ.get("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "0") != "0":
|
||||||
arch_list += ["gfx1101", "gfx1150", "gfx1151", "gfx1200"]
|
arch_list += ["gfx1201", "gfx950"]
|
||||||
return evaluate_gfx_arch_within(arch_list)
|
return evaluate_gfx_arch_within(arch_list)
|
||||||
if TEST_CUDA:
|
if TEST_CUDA:
|
||||||
return True
|
return True
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user