Revert "[ROCm] Bump AOTriton to 0.10b (#156290)"

This reverts commit 34d8e64ef6.

Reverted https://github.com/pytorch/pytorch/pull/156290 on behalf of https://github.com/atalman due to failing multiple internal tests ([comment](https://github.com/pytorch/pytorch/pull/156290#issuecomment-2992072727))
PyTorch MergeBot 2025-06-20 15:35:25 +00:00
parent b4442f42a9
commit 1036f6d114
7 changed files with 241 additions and 368 deletions

View File

@@ -1113,10 +1113,8 @@ _flash_attention_forward(
   std::optional<Tensor> alibi_slopes = _alibi_slopes;
   const float softcap = 0.0;

-#ifndef USE_ROCM // ROCM backend accepts std::optional for window_size_left/right directly.
-  const int non_null_window_left = window_size_left.value_or(-1);
-  const int non_null_window_right = window_size_right.value_or(-1);
-#endif
+  const int non_null_window_left = window_size_left.has_value() ? window_size_left.value() : -1;
+  const int non_null_window_right = window_size_right.has_value() ? window_size_right.value() : -1;

   // We are going to have two paths:
   // 1. The standard MHA path for dense tensors
@@ -1153,13 +1151,8 @@ _flash_attention_forward(
         softmax_scale,
         false /*zero_tensors*/,
         is_causal,
-#ifdef USE_ROCM
-        window_size_left,
-        window_size_right,
-#else
         non_null_window_left,
         non_null_window_right,
-#endif
         softcap,
         return_debug_mask,
         std::nullopt /*gen_*/);
@@ -1182,13 +1175,8 @@ _flash_attention_forward(
         dropout_p,
         softmax_scale,
         is_causal,
-#ifdef USE_ROCM
-        window_size_left,
-        window_size_right,
-#else
         non_null_window_left,
         non_null_window_right,
-#endif
         softcap,
         return_debug_mask, /*return_softmax (this is used for testing)*/
         std::nullopt);
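
The net effect of the hunks above is that the caller once again collapses the optional window sizes into plain ints before invoking the kernels, instead of passing std::optional through on ROCm only. A minimal standalone sketch of that conversion (the helper name is invented for illustration, it is not part of the patch or of PyTorch):

#include <cstdint>
#include <optional>
#include <utility>

// Hedged sketch: an unset window dimension means "no limit" and is encoded
// as -1 for the kernel interface, mirroring the restored caller code.
std::pair<int, int> to_kernel_window(std::optional<int64_t> window_size_left,
                                     std::optional<int64_t> window_size_right) {
  const int non_null_window_left =
      window_size_left.has_value() ? static_cast<int>(window_size_left.value()) : -1;
  const int non_null_window_right =
      window_size_right.has_value() ? static_cast<int>(window_size_right.value()) : -1;
  return {non_null_window_left, non_null_window_right};
}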

View File

@@ -87,10 +87,8 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
   auto contiguous_grad_out = grad_out.contiguous();
   auto contiguous_out = out.contiguous();
-#ifndef USE_ROCM // ROCM backend accepts std::optional for window_size_left/right directly.
   const int non_null_window_left = window_size_left.has_value() ? window_size_left.value() : -1;
   const int non_null_window_right = window_size_right.has_value() ? window_size_right.value() : -1;
-#endif

   std::optional<at::Tensor> dq{std::nullopt};
   std::optional<at::Tensor> dk{std::nullopt};
@@ -138,13 +136,8 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
         softmax_scale,
         false /*zero_tensors*/,
         is_causal,
-#ifdef USE_ROCM
-        window_size_left,
-        window_size_right,
-#else
         non_null_window_left,
         non_null_window_right,
-#endif
         softcap,
         determinisitic,
         philox_seed,
@@ -166,13 +159,8 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
         dropout_p,
         softmax_scale,
         is_causal,
-#ifdef USE_ROCM
-        window_size_left,
-        window_size_right,
-#else
         non_null_window_left,
         non_null_window_right,
-#endif
         softcap,
         determinisitic,
         philox_seed,

View File

@@ -64,14 +64,8 @@
 #include <aotriton/flash.h>
 #include <aotriton/runtime.h>
-#if AOTRITON_VERSION_MINOR < 9
-#error "This adaptor code is only tested with AOTriton >= 0.9"
-#endif
-#if (AOTRITON_VERSION_MAJOR * 100 + AOTRITON_VERSION_MINOR) >= 10
-#define V3_API 1
-#else
-#define V3_API 0
-#endif
+#if AOTRITON_VERSION_MINOR != 9
+#error "This adaptor code is only tested with AOTriton 0.9.x"
+#endif

 namespace pytorch_flash {
@@ -87,38 +81,6 @@ void check_gpu_arch(hipStream_t stream) {
   }
 }

-std::tuple<bool, int, int>
-calculate_swa(std::optional<int64_t> window_size_left,
-              std::optional<int64_t> window_size_right,
-              int max_seqlen_q,
-              int max_seqlen_k,
-              bool is_causal) {
-#if V3_API // SWA is exposed through V3 API
-  bool needs_swa = false;
-  using aotriton::v3::flash::WindowValue;
-  // Default values when std::optional window_size_left/right have no value
-  int window_left = max_seqlen_q;
-  int window_right = max_seqlen_k;
-  if (is_causal) {
-    window_left = WindowValue::TopLeftAligned;
-    window_right = WindowValue::TopLeftAligned;
-  }
-  if (window_size_left.has_value() || window_size_right.has_value()) {
-    needs_swa = true;
-    window_left = window_size_left.value_or(window_left);
-    window_right = window_size_right.value_or(window_right);
-  }
-  return std::make_tuple(needs_swa, window_left, window_right);
-#else
-  if (window_size_left.has_value() || window_size_right.has_value()) {
-    TORCH_WARN_ONCE("Current AOTriton does not support sliding window attention (SWA)."
-                    " Both window_size_left and window_size_right will be ignored."
-                    " Re-compile PyTorch with AOTriton >= 0.10b to enable SWA support.");
-  }
-  return std::make_tuple(false, 0, 0);
-#endif
-}
-
 // We want to checkpoint and save the RNG state for backward if dropout
 // We get the default generator and return the seed and offset which will
 // be used in the backward function
@@ -165,8 +127,8 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x
             const float p_dropout,
             const float softmax_scale,
             bool is_causal,
-            std::optional<int64_t> window_size_left,
-            std::optional<int64_t> window_size_right,
+            int window_size_left,
+            int window_size_right,
             const bool return_softmax,
             const std::optional<at::Generator>& gen_) {
   auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
@@ -199,6 +161,7 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x
   TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
   if (seqlen_q == 1) { is_causal = false; }  // causal=true is the same as causal=false in this case
+  if (is_causal) { window_size_right = 0; }

   CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
   CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
@@ -249,19 +212,6 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x
     atomic_counter = at::zeros({1}, opts.dtype(at::kInt));
   }

-  auto [needs_swa, window_left, window_right] = calculate_swa(window_size_left,
-                                                              window_size_right,
-                                                              seqlen_q,
-                                                              seqlen_k,
-                                                              is_causal);
-#if V3_API
-  const bool uses_swa = needs_swa;
-#else
-  // When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be
-  // optimized out (hopefully).
-  constexpr bool uses_swa = false;
-#endif
-
   hipError_t err; // TODO: Error handling
   using aotriton::v2::flash::attn_fwd;
   using sdp::aotriton_adapter::mk_aotensor;
@@ -276,54 +226,23 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x
   auto seed_output = mk_philoxtensor(use_philox_state ? seed_t.data_ptr<int64_t>() : nullptr);
   auto offset_output = mk_philoxtensor(use_philox_state ? offset_t.data_ptr<int64_t>() : nullptr);
   auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr<int32_t>() : nullptr);
-  if (uses_swa) {
-#if V3_API
-    using aotriton::v3::flash::CausalType;
-    using aotriton::v3::flash::VarlenType;
-    aotriton::v3::flash::attn_fwd_params params;
-    params.Q = mk_aotensor(q_t, "q");
-    params.K = mk_aotensor(k_t, "k");
-    params.V = mk_aotensor(v_t, "v");
-    params.Sm_scale = softmax_scale;
-    params.L = mk_aotensor<2>(M, "M");
-    params.Out = mk_aotensor(output_t, "Out");
-    params.Max_seqlen_q = seqlen_q;  // Unused if cu_seqlens_q is empty
-    params.Max_seqlen_k = seqlen_k;  // Unused if cu_seqlens_k is empty
-    params.dropout_p = p_dropout;
-    params.philox_seed_ptr = seed;
-    params.philox_offset1 = offset1;
-    params.philox_offset2 = offset2;
-    params.philox_seed_output = seed_output;
-    params.philox_offset_output = offset_output;
-    params.encoded_softmax = mk_aotensor(softmax_fa_t, "encoded_softmax");
-    params.persistent_atomic_counter = persistent_counter;
-    params.causal_type = CausalType::WindowedAttention;
-    params.varlen_type = VarlenType::None;
-    params.window_left = window_left;
-    params.window_right = window_right;
-    err = aotriton::v3::flash::attn_fwd(params,
-                                        aotriton::v3::flash::attn_fwd_params::kVersion,
-                                        stream);
-#endif
-  } else {
-    err = attn_fwd(mk_aotensor(q_t, "q"),
-                   mk_aotensor(k_t, "k"),
-                   mk_aotensor(v_t, "v"),
-                   empty_bias,
-                   softmax_scale,
-                   mk_aotensor<2>(M, "M"),
-                   mk_aotensor(output_t, "Out"),
-                   p_dropout,
-                   seed,
-                   offset1,
-                   offset2,
-                   seed_output,
-                   offset_output,
-                   mk_aotensor(softmax_fa_t, "encoded_softmax"),
-                   is_causal,
-                   persistent_counter,
-                   stream);
-  }
+  err = attn_fwd(mk_aotensor(q_t, "q"),
+                 mk_aotensor(k_t, "k"),
+                 mk_aotensor(v_t, "v"),
+                 empty_bias,
+                 softmax_scale,
+                 mk_aotensor<2>(M, "M"),
+                 mk_aotensor(output_t, "Out"),
+                 p_dropout,
+                 seed,
+                 offset1,
+                 offset2,
+                 seed_output,
+                 offset_output,
+                 mk_aotensor(softmax_fa_t, "encoded_softmax"),
+                 is_causal,
+                 persistent_counter,
+                 stream);
   return {out, q_padded, k_padded, v_padded, M.view({batch_size, num_heads, seqlen_q}), seed_t, offset_t, softmax_fa_t};
 }
@@ -344,8 +263,8 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
                    const float softmax_scale,
                    const bool zero_tensors,
                    bool is_causal,
-                   std::optional<int64_t> window_size_left,
-                   std::optional<int64_t> window_size_right,
+                   int window_size_left,
+                   int window_size_right,
                    const bool return_softmax,
                    const std::optional<at::Generator>& gen_) {
   TORCH_CHECK(!seqused_k.has_value(), "[ROCm] mha_varlen_fwd: seqused_k must be nullopt");
@@ -393,6 +312,13 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
   TORCH_CHECK(head_size_og <= 512, "FlashAttention on ROCm forward only supports head dimension at most 512");
   TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

+  if (window_size_left >= max_seqlen_k) {
+    window_size_left = -1;
+  }
+  if (window_size_right >= max_seqlen_k) {
+    window_size_right = -1;
+  }
+
   CHECK_SHAPE(temp_q, total_q, num_heads, head_size_og);
   const int total_k = k.size(0);
   CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
@@ -442,19 +368,6 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
     }
   }

-  auto [needs_swa, window_left, window_right] = calculate_swa(window_size_left,
-                                                              window_size_right,
-                                                              max_seqlen_q,
-                                                              max_seqlen_k,
-                                                              is_causal);
-#if V3_API
-  const bool uses_swa = needs_swa;
-#else
-  // When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be
-  // optimized out (hopefully).
-  constexpr bool uses_swa = false;
-#endif
-
   auto [seed_t, offset_t, philox_state, use_philox_state] =
       prepare_philox_arguments(p_dropout, batch_size * num_heads * 32);
@@ -477,58 +390,27 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
     auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr<int64_t>()) : nullscalar;
     auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr<int64_t>()) : nullscalar;
     auto persistent_counter = is_causal ? mk_philoxtensor(atomic_counter.data_ptr<int64_t>()) : nullscalar;
-    if (uses_swa) {
-      using aotriton::v3::flash::CausalType;
-      using aotriton::v3::flash::VarlenType;
-      aotriton::v3::flash::attn_fwd_params params;
-      params.Q = mk_aotensor(q_padded, "q");
-      params.K = mk_aotensor(k_padded, "k");
-      params.V = mk_aotensor(v_padded, "v");
-      params.Sm_scale = softmax_scale;
-      params.L = mk_aotensor<2>(M, "M");
-      params.Out = mk_aotensor(out_padded, "Out");
-      params.cu_seqlens_q = mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q");
-      params.cu_seqlens_k = mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k");
-      params.Max_seqlen_q = max_seqlen_q;  // Unused if cu_seqlens_q is empty
-      params.Max_seqlen_k = max_seqlen_k;  // Unused if cu_seqlens_k is empty
-      params.dropout_p = p_dropout;
-      params.philox_seed_ptr = seed;
-      params.philox_offset1 = offset1;
-      params.philox_offset2 = offset2;
-      params.philox_seed_output = seed_output;
-      params.philox_offset_output = offset_output;
-      params.encoded_softmax = mk_aotensor(softmax_fa_t, "encoded_softmax");
-      params.persistent_atomic_counter = persistent_counter;
-      params.causal_type = CausalType::WindowedAttention;
-      params.varlen_type = VarlenType::CompactVarlen;
-      params.window_left = window_left;
-      params.window_right = window_right;
-      err = aotriton::v3::flash::attn_fwd(params,
-                                          aotriton::v3::flash::attn_fwd_params::kVersion,
-                                          stream);
-    } else {
-      err = attn_fwd_compact_varlen(mk_aotensor(q_padded, "q"),
-                                    mk_aotensor(k_padded, "k"),
-                                    mk_aotensor(v_padded, "v"),
-                                    empty_bias,
-                                    mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"),
-                                    mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"),
-                                    max_seqlen_q,
-                                    max_seqlen_k,
-                                    softmax_scale,
-                                    mk_aotensor<2>(M, "M"),
-                                    mk_aotensor(out_padded, "Out"),
-                                    p_dropout,
-                                    seed,
-                                    offset1,
-                                    offset2,
-                                    seed_output,
-                                    offset_output,
-                                    mk_aotensor(softmax_fa_t, "encoded_softmax"),
-                                    is_causal,
-                                    persistent_counter,
-                                    stream);
-    }
+    err = attn_fwd_compact_varlen(mk_aotensor(q_padded, "q"),
+                                  mk_aotensor(k_padded, "k"),
+                                  mk_aotensor(v_padded, "v"),
+                                  empty_bias,
+                                  mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"),
+                                  mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"),
+                                  max_seqlen_q,
+                                  max_seqlen_k,
+                                  softmax_scale,
+                                  mk_aotensor<2>(M, "M"),
+                                  mk_aotensor(out_padded, "Out"),
+                                  p_dropout,
+                                  seed,
+                                  offset1,
+                                  offset2,
+                                  seed_output,
+                                  offset_output,
+                                  mk_aotensor(softmax_fa_t, "encoded_softmax"),
+                                  is_causal,
+                                  persistent_counter,
+                                  stream);
   } else {
     // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
     out.zero_();
@@ -552,8 +434,8 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea
             const float p_dropout, // probability to drop
             const float softmax_scale,
             const bool is_causal,
-            std::optional<int64_t> window_size_left,
-            std::optional<int64_t> window_size_right,
+            int window_size_left,
+            int window_size_right,
             const bool deterministic,
             const at::Tensor& philox_seed,
             const at::Tensor& philox_offset) {
@@ -642,19 +524,6 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea
     dv = at::empty_like(k);
   }

-  auto [needs_swa, window_left, window_right] = calculate_swa(window_size_left,
-                                                              window_size_right,
-                                                              seqlen_q,
-                                                              seqlen_k,
-                                                              is_causal);
-#if V3_API
-  const bool uses_swa = needs_swa;
-#else
-  // When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be
-  // optimized out (hopefully).
-  constexpr bool uses_swa = false;
-#endif
-
   auto opts = q.options();
   auto softmax_d = at::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
@@ -672,40 +541,10 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea
   int d_head = head_size_og;
   bool use_fused_bwd = d_head <= 192 && d_head * seqlen_q < 64 * 512;
   hipError_t err; // TODO: Error handling
-  using sdp::aotriton_adapter::mk_aotensor;
-  using sdp::aotriton_adapter::mk_aoscalartensor;
-  if (uses_swa) {
-    // Fused BWD does not support SWA
-    at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous();
-    using aotriton::v3::flash::CausalType;
-    using aotriton::v3::flash::VarlenType;
-    aotriton::v3::flash::attn_bwd_params params;
-    params.Q = mk_aotensor(q_t, "q");
-    params.K = mk_aotensor(k_t, "k");
-    params.V = mk_aotensor(v_t, "v");
-    params.Sm_scale = softmax_scale;
-    params.Out = mk_aotensor(out_t, "out");
-    params.DO = mk_aotensor(dout_t, "dout");
-    params.DK = mk_aotensor(dq_t, "dq");
-    params.DV = mk_aotensor(dk_t, "dk");
-    params.DQ = mk_aotensor(dv_t, "dv");
-    params.L = mk_aotensor<2>(softmax_lse_cont, "L");
-    params.D = mk_aotensor<2>(delta, "delta");
-    params.Max_seqlen_q = seqlen_q;  // Unused if cu_seqlens_q is empty
-    params.Max_seqlen_k = seqlen_k;  // Unused if cu_seqlens_k is empty
-    params.dropout_p = p_dropout;
-    params.philox_seed_ptr = mk_aoscalartensor(philox_seed);
-    params.philox_offset1 = mk_aoscalartensor(philox_offset);
-    params.philox_offset2 = 0;
-    params.causal_type = CausalType::WindowedAttention;
-    params.varlen_type = VarlenType::None;
-    params.window_left = window_left;
-    params.window_right = window_right;
-    err = aotriton::v3::flash::attn_bwd(params,
-                                        aotriton::v3::flash::attn_bwd_params::kVersion,
-                                        stream);
-  } else if (use_fused_bwd) {
+  if (use_fused_bwd) {
     using aotriton::v2::flash::attn_bwd_fused;
+    using sdp::aotriton_adapter::mk_aotensor;
+    using sdp::aotriton_adapter::mk_aoscalartensor;
     using sdp::aotriton_adapter::cast_dtype;
     aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
     err = attn_bwd_fused(mk_aotensor(q_t, "q"),
@@ -729,6 +568,8 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea
   } else {
     at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous();
     using aotriton::v2::flash::attn_bwd;
+    using sdp::aotriton_adapter::mk_aotensor;
+    using sdp::aotriton_adapter::mk_aoscalartensor;
     using sdp::aotriton_adapter::cast_dtype;
     aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
     err = attn_bwd(mk_aotensor(q_t, "q"),
@@ -774,14 +615,17 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
                    const float softmax_scale,
                    const bool zero_tensors,
                    const bool is_causal,
-                   std::optional<int64_t> window_size_left,
-                   std::optional<int64_t> window_size_right,
+                   int window_size_left,
+                   int window_size_right,
                    const bool deterministic,
                    const at::Tensor& philox_seed,
                    const at::Tensor& philox_offset)
 {
   TORCH_CHECK(!alibi_slopes_.has_value(), "[ROCm] mha_varlen_fwd: alibi_slopes_ must be nullopt");
+  if (is_causal) {
+    window_size_right = 0;
+  }
   // Otherwise the kernel will be launched from cuda:0 device
   // Cast to char to avoid compiler warning about narrowing
   at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
@@ -825,6 +669,9 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
   TORCH_CHECK(head_size <= 512, "FlashAttention on ROCm backward only supports head dimension at most 512");
   TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

+  if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
+  if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
+
   CHECK_SHAPE(q, total_q, num_heads, head_size);
   CHECK_SHAPE(k, total_k, num_heads_k, head_size);
   CHECK_SHAPE(v, total_k, num_heads_k, head_size);
@@ -887,19 +734,6 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
     softmax_d.zero_();
   }

-  auto [needs_swa, window_left, window_right] = calculate_swa(window_size_left,
-                                                              window_size_right,
-                                                              max_seqlen_q,
-                                                              max_seqlen_k,
-                                                              is_causal);
-#if V3_API
-  const bool uses_swa = needs_swa;
-#else
-  // When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be
-  // optimized out (hopefully).
-  constexpr bool uses_swa = false;
-#endif
-
   at::PhiloxCudaState philox_args;
   if (is_dropout) {
     if (at::cuda::currentStreamCaptureStatus() ==
@@ -913,66 +747,34 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
     }
   }
   if (max_seqlen_q > 0) {
     hipError_t err; // TODO: Error handling
-    using aotriton::v2::flash::attn_bwd_compact_varlen;
     using sdp::aotriton_adapter::mk_aotensor;
     using sdp::aotriton_adapter::mk_aoscalartensor;
-    if (uses_swa) {
-      using aotriton::v3::flash::CausalType;
-      using aotriton::v3::flash::VarlenType;
-      aotriton::v3::flash::attn_bwd_params params;
-      params.Q = mk_aotensor(q_padded, "q");
-      params.K = mk_aotensor(k_padded, "k");
-      params.V = mk_aotensor(v_padded, "v");
-      params.Sm_scale = softmax_scale;
-      params.Out = mk_aotensor(out_t, "out");
-      params.DO = mk_aotensor(dout_t, "dout");
-      params.DK = mk_aotensor(dq_padded, "dq");
-      params.DV = mk_aotensor(dk_padded, "dk");
-      params.DQ = mk_aotensor(dv_padded, "dv");
-      params.L = mk_aotensor<2>(softmax_lse_cont, "L");
-      params.D = mk_aotensor<2>(delta, "delta");
-      params.cu_seqlens_q = mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q");
-      params.cu_seqlens_k = mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k");
-      params.Max_seqlen_q = max_seqlen_q;  // Unused if cu_seqlens_q is empty
-      params.Max_seqlen_k = max_seqlen_k;  // Unused if cu_seqlens_k is empty
-      params.dropout_p = p_dropout;
-      params.philox_seed_ptr = mk_aoscalartensor(philox_seed);
-      params.philox_offset1 = mk_aoscalartensor(philox_offset);
-      params.philox_offset2 = 0;
-      params.causal_type = CausalType::WindowedAttention;
-      params.varlen_type = VarlenType::CompactVarlen;
-      params.window_left = window_left;
-      params.window_right = window_right;
-      err = aotriton::v3::flash::attn_bwd(params,
-                                          aotriton::v3::flash::attn_bwd_params::kVersion,
-                                          stream);
-    } else {
-      using aotriton::v2::flash::attn_bwd_compact_varlen;
-      using sdp::aotriton_adapter::cast_dtype;
-      aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
-      err = attn_bwd_compact_varlen(mk_aotensor(q_padded, "q"),
-                                    mk_aotensor(k_padded, "k"),
-                                    mk_aotensor(v_padded, "v"),
-                                    mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"),
-                                    mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"),
-                                    max_seqlen_q,
-                                    max_seqlen_k,
-                                    empty_bias,
-                                    softmax_scale,
-                                    mk_aotensor(out_t, "out"),
-                                    mk_aotensor(dout_t, "dout"),
-                                    mk_aotensor(dq_padded, "dq"),
-                                    mk_aotensor(dk_padded, "dk"),
-                                    mk_aotensor(dv_padded, "dv"),
-                                    empty_bias,
-                                    mk_aotensor<2>(softmax_lse_cont, "L"),
-                                    mk_aotensor<2>(delta, "delta"),
-                                    p_dropout,
-                                    mk_aoscalartensor(philox_seed),
-                                    mk_aoscalartensor(philox_offset),
-                                    0,
-                                    is_causal,
-                                    stream);
-    }
+    using sdp::aotriton_adapter::cast_dtype;
+    aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
+    err = attn_bwd_compact_varlen(mk_aotensor(q_padded, "q"),
+                                  mk_aotensor(k_padded, "k"),
+                                  mk_aotensor(v_padded, "v"),
+                                  mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"),
+                                  mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"),
+                                  max_seqlen_q,
+                                  max_seqlen_k,
+                                  empty_bias,
+                                  softmax_scale,
+                                  mk_aotensor(out_t, "out"),
+                                  mk_aotensor(dout_t, "dout"),
+                                  mk_aotensor(dq_padded, "dq"),
+                                  mk_aotensor(dk_padded, "dk"),
+                                  mk_aotensor(dv_padded, "dv"),
+                                  empty_bias,
+                                  mk_aotensor<2>(softmax_lse_cont, "L"),
+                                  mk_aotensor<2>(delta, "delta"),
+                                  p_dropout,
+                                  mk_aoscalartensor(philox_seed),
+                                  mk_aoscalartensor(philox_offset),
+                                  0,
+                                  is_causal,
+                                  stream);
   } else {
     // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
     dq.zero_();
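
Taken together, the hunks in this file restore the pre-0.10 ROCm adaptor: the AOT entry points accept plain int window sizes again, a causal mask forces window_size_right to 0, and a window at least as large as the key sequence is treated as unlimited (-1). A small self-contained sketch of that normalization, assuming those are the only rules (the helper name is invented, not a PyTorch function):

#include <cstdint>

// Hedged sketch of the window normalization visible in this file's diff;
// normalize_window() is only an illustration, not part of the patch.
void normalize_window(bool is_causal, int max_seqlen_k,
                      int& window_size_left, int& window_size_right) {
  if (is_causal) {
    window_size_right = 0;              // causal attention only looks backwards
  }
  if (window_size_left >= max_seqlen_k) {
    window_size_left = -1;              // -1 means "no limit" for the kernels
  }
  if (window_size_right >= max_seqlen_k) {
    window_size_right = -1;
  }
}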

View File

@@ -51,8 +51,8 @@ mha_fwd_aot(
     const float p_dropout,
     const float softmax_scale,
     bool is_causal,
-    std::optional<int64_t> window_size_left,
-    std::optional<int64_t> window_size_right,
+    int window_size_left,
+    int window_size_right,
     const bool return_softmax,
     const std::optional<at::Generator>& gen_);
@@ -87,8 +87,8 @@ mha_varlen_fwd_aot(
     const float softmax_scale,
     const bool zero_tensors,
     bool is_causal,
-    std::optional<int64_t> window_size_left,
-    std::optional<int64_t> window_size_right,
+    int window_size_left,
+    int window_size_right,
     const bool return_softmax,
     const std::optional<at::Generator>& gen_);
@@ -110,8 +110,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd_aot(
     const float p_dropout, // probability to drop
     const float softmax_scale,
     const bool is_causal,
-    std::optional<int64_t> window_size_left,
-    std::optional<int64_t> window_size_right,
+    int window_size_left,
+    int window_size_right,
     const bool deterministic,
     const at::Tensor& philox_seed,
     const at::Tensor& philox_offset);
@@ -141,8 +141,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd_aot(
     const float softmax_scale,
     const bool zero_tensors,
     const bool is_causal,
-    std::optional<int64_t> window_size_left,
-    std::optional<int64_t> window_size_right,
+    int window_size_left,
+    int window_size_right,
     const bool deterministic,
     const at::Tensor& philox_seed,
     const at::Tensor& philox_offset);
@@ -290,16 +290,14 @@ mha_fwd(
     const float p_dropout,
     const float softmax_scale,
     bool is_causal,
-    std::optional<int64_t> window_size_left,
-    std::optional<int64_t> window_size_right,
+    int window_size_left,
+    int window_size_right,
     const float softcap,
     const bool return_softmax,
     std::optional<at::Generator> gen_) {
 #if defined(USE_CK_FLASH_ATTENTION)
   if (at::globalContext().getROCmFAPreferredBackend() ==
       at::ROCmFABackend::Ck) {
-    const int non_null_window_left = window_size_left.value_or(-1);
-    const int non_null_window_right = window_size_right.value_or(-1);
     std::optional<at::Tensor> dummy_attn_bias = std::nullopt;
     return mha_fwd_ck(
         q,
@@ -309,13 +307,27 @@ mha_fwd(
         p_dropout,
         softmax_scale,
         is_causal,
-        non_null_window_left,
-        non_null_window_right,
+        window_size_left,
+        window_size_right,
         return_softmax,
         gen_,
         dummy_attn_bias); // Not used in flash attention
+  } else {
+    return mha_fwd_aot(
+        q,
+        k,
+        v,
+        out_,
+        alibi_slopes_,
+        p_dropout,
+        softmax_scale,
+        is_causal,
+        window_size_left,
+        window_size_right,
+        return_softmax,
+        gen_);
   }
-#endif
+#else
   return mha_fwd_aot(
       q,
       k,
@@ -329,6 +341,7 @@ mha_fwd(
       window_size_right,
       return_softmax,
       gen_);
+#endif
 }

 inline std::tuple<
@@ -363,8 +376,8 @@ mha_varlen_fwd(
     const float softmax_scale,
     const bool zero_tensors,
     bool is_causal,
-    std::optional<int64_t> window_size_left,
-    std::optional<int64_t> window_size_right,
+    int window_size_left,
+    int window_size_right,
     const float softcap,
     const bool return_softmax,
     std::optional<at::Generator> gen_) {
@@ -372,8 +385,6 @@ mha_varlen_fwd(
   if (at::globalContext().getROCmFAPreferredBackend() ==
       at::ROCmFABackend::Ck) {
     std::optional<at::Tensor> dummy_attn_bias = std::nullopt;
-    const int non_null_window_left = window_size_left.value_or(-1);
-    const int non_null_window_right = window_size_right.value_or(-1);
     return mha_varlen_fwd_ck(
         q,
         k,
@@ -388,13 +399,34 @@ mha_varlen_fwd(
         softmax_scale,
         zero_tensors,
         is_causal,
-        non_null_window_left,
-        non_null_window_right,
+        window_size_left,
+        window_size_right,
         return_softmax,
         gen_,
         dummy_attn_bias); // Not used in flash attention
+  } else {
+    return mha_varlen_fwd_aot(
+        q,
+        k,
+        v,
+        out_,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        seqused_k,
+        block_table_,
+        alibi_slopes_,
+        max_seqlen_q,
+        max_seqlen_k,
+        p_dropout,
+        softmax_scale,
+        zero_tensors,
+        is_causal,
+        window_size_left,
+        window_size_right,
+        return_softmax,
+        gen_);
   }
-#endif
+#else
   return mha_varlen_fwd_aot(
       q,
       k,
@@ -415,6 +447,7 @@ mha_varlen_fwd(
       window_size_right,
       return_softmax,
       gen_);
+#endif
 }

 inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd(
@@ -435,18 +468,16 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd(
     const float p_dropout, // probability to drop
     const float softmax_scale,
     const bool is_causal,
-    std::optional<int64_t> window_size_left,
-    std::optional<int64_t> window_size_right,
+    int window_size_left,
+    int window_size_right,
     const float softcap,
     const bool deterministic,
     const at::Tensor philox_seed,
     const at::Tensor philox_offset) {
+#if defined(USE_CK_FLASH_ATTENTION)
   if (at::globalContext().getROCmFAPreferredBackend() ==
       at::ROCmFABackend::Ck) {
-#if defined(USE_CK_FLASH_ATTENTION)
     std::optional<at::Tensor> non_null_dbias = std::nullopt;
-    const int non_null_window_left = window_size_left.value_or(-1);
-    const int non_null_window_right = window_size_right.value_or(-1);
     auto[dQuery,
          dKey,
          dValue,
@@ -467,16 +498,38 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd(
             p_dropout,
             softmax_scale,
             is_causal,
-            non_null_window_left,
-            non_null_window_right,
+            window_size_left,
+            window_size_right,
             deterministic,
             philox_seed,
             philox_offset);
     // for FA return [dQ, dV, dK, dSoftmax]
     return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue), std::move(dSoftmax));
+  } else {
+    return mha_bwd_aot(
+        dout,
+        q,
+        k,
+        v,
+        out,
+        softmax_lse,
+        dq_,
+        dk_,
+        dv_,
+        alibi_slopes_,
+        p_dropout,
+        softmax_scale,
+        is_causal,
+        window_size_left,
+        window_size_right,
+        deterministic,
+        philox_seed,
+        philox_offset);
+  }
 #else
+  if(at::globalContext().getROCmFAPreferredBackend() ==
+     at::ROCmFABackend::Ck) {
     TORCH_WARN_ONCE("Warning! You have opted to use CK flash attention backend in a build that was not compiled using USE_CK_FLASH_ATTENTION=1. Please set this variable and try again. Defaulting to use aotriton backend...");
-#endif
   }
   return mha_bwd_aot(
       dout,
@@ -497,6 +550,7 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd(
       deterministic,
       philox_seed,
       philox_offset);
+#endif
 }

 inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd(
@@ -524,8 +578,8 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd
     const float softmax_scale,
     const bool zero_tensors,
     const bool is_causal,
-    std::optional<int64_t> window_size_left,
-    std::optional<int64_t> window_size_right,
+    int window_size_left,
+    int window_size_right,
     const float softcap,
     const bool deterministic,
     const at::Tensor philox_seed,
@@ -534,8 +588,6 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd
   if (at::globalContext().getROCmFAPreferredBackend() ==
       at::ROCmFABackend::Ck) {
     std::optional<at::Tensor> non_null_dbias = std::nullopt;
-    const int non_null_window_left = window_size_left.value_or(-1);
-    const int non_null_window_right = window_size_right.value_or(-1);
     auto[dQuery,
          dKey,
          dValue,
@@ -561,15 +613,40 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd
         softmax_scale,
         zero_tensors,
         is_causal,
-        non_null_window_left,
-        non_null_window_right,
+        window_size_left,
+        window_size_right,
         deterministic,
         philox_seed,
         philox_offset);
     // for FA return [dQ, dV, dK, dSoftmax]
     return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue), std::move(dSoftmax));
+  } else {
+    return mha_varlen_bwd_aot(
+        dout,
+        q,
+        k,
+        v,
+        out,
+        softmax_lse,
+        dq_,
+        dk_,
+        dv_,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        alibi_slopes_,
+        max_seqlen_q,
+        max_seqlen_k,
+        p_dropout,
+        softmax_scale,
+        zero_tensors,
+        is_causal,
+        window_size_left,
+        window_size_right,
+        deterministic,
+        philox_seed,
+        philox_offset);
   }
-#endif
+#else
   return mha_varlen_bwd_aot(
       dout,
       q,
@@ -594,6 +671,7 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd
       deterministic,
       philox_seed,
       philox_offset);
+#endif
 }

 } // namespace pytorch_flash
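
The restored dispatch in this header selects the CK backend only when PyTorch was built with USE_CK_FLASH_ATTENTION and the user chose CK at runtime, and otherwise falls back to the AOTriton path; without the build flag, a runtime CK request only triggers a warning. A stripped-down sketch of that shape (placeholder function names, not the real signatures):

// Hedged sketch of the restored dispatch structure; backend_is_ck(), run_ck()
// and run_aot() are placeholders for illustration, not PyTorch functions.
bool backend_is_ck();
int run_ck();
int run_aot();

int dispatch_flash_forward() {
#if defined(USE_CK_FLASH_ATTENTION)
  if (backend_is_ck()) {
    return run_ck();   // CK build and the user asked for CK at runtime
  } else {
    return run_aot();  // CK build, but the user kept the AOTriton default
  }
#else
  // Without the CK build flag the request cannot be honored; fall back.
  return run_aot();
#endif
}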

View File

@@ -1,3 +1,16 @@
+macro(get_target_gpus_from_pytorch target_gpus)
+  set(gfx90a_key MI200)
+  set(gfx942_key MI300X)
+  set(gfx1100_key Navi31)
+
+  foreach(X IN LISTS PYTORCH_ROCM_ARCH)
+    set(key ${X})
+    string(APPEND key "_key")
+    string(APPEND target_gpus ${${key}})
+    string(APPEND target_gpus "|")
+  endforeach()
+endmacro()
+
 if(NOT __AOTRITON_INCLUDED)
   set(__AOTRITON_INCLUDED TRUE)
@@ -9,22 +22,22 @@ if(NOT __AOTRITON_INCLUDED)
   # Replaces .ci/docker/aotriton_version.txt
   # Note packages information may have versions skipped (due to no ABI breaks)
   # But they must be listed from lower version to higher version
-  set(__AOTRITON_VER "0.10b")
+  set(__AOTRITON_VER "0.9.2b")
   set(__AOTRITON_MANYLINUX_LIST
+      "manylinux_2_28" # rocm6.2
       "manylinux_2_28" # rocm6.3
       "manylinux_2_28" # rocm6.4
-      "manylinux_2_28" # rocm7.0
      )
   set(__AOTRITON_ROCM_LIST
+      "rocm6.2"
       "rocm6.3"
       "rocm6.4"
-      "rocm7.0"
      )
-  set(__AOTRITON_CI_COMMIT "6fca155f4deeb8d9529326f7b69f350aeeb93477")
+  set(__AOTRITON_CI_COMMIT "b388d223d8c7213545603e00f6f3148c54d1f525")
   set(__AOTRITON_SHA256_LIST
-      "861cd9f7479eec943933c27cb86920247e5b5dd139bc7c1376c81808abb7d7fe" # rocm6.3
-      "acea7d811a2d3bbe718b6e07fc2a9f739e49eecd60b4b6a36fcb3fe8edf85d78" # rocm6.4
-      "7e29c325d5bd33ba896ddb106f5d4fc7d715274dca7fe937f724fffa82017838" # rocm7.0
+      "08d84f96f4c984179f80f517c0431c7511ee26bb0ce9bd05a827573ddd78cc79" # rocm6.2
+      "9094d59717e7e6eace9126ca100dd0e86510f07fc6c3a349569fc4e2d9056604" # rocm6.3
+      "41190202c2736d5ff75b13a3abc0fb52ebfbb67226cf85dc3de7699c7000db44" # rocm6.4
      )
   set(__AOTRITON_Z "gz")
@@ -37,13 +50,17 @@ if(NOT __AOTRITON_INCLUDED)
     set(__AOTRITON_INSTALL_DIR "$ENV{AOTRITON_INSTALLED_PREFIX}")
     message(STATUS "Using Preinstalled AOTriton at ${__AOTRITON_INSTALL_DIR}")
   elseif(DEFINED ENV{AOTRITON_INSTALL_FROM_SOURCE})
+    set(target_gpus "")
+    get_target_gpus_from_pytorch(target_gpus)
     ExternalProject_Add(aotriton_external
       GIT_REPOSITORY https://github.com/ROCm/aotriton.git
       GIT_TAG ${__AOTRITON_CI_COMMIT}
       PREFIX ${__AOTRITON_EXTERN_PREFIX}
       INSTALL_DIR ${__AOTRITON_INSTALL_DIR}
+      LIST_SEPARATOR |
       CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__AOTRITON_INSTALL_DIR}
-      -DAOTRITON_TARGET_ARCH:STRING=${PYTORCH_ROCM_ARCH}
+      -DTARGET_GPUS:STRING=${target_gpus}
+      -DAOTRITON_COMPRESS_KERNEL=ON
       -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
       -DAOTRITON_NO_PYTHON=ON
       -DAOTRITON_NO_SHARED=OFF

View File

@@ -3294,7 +3294,7 @@ class TestSDPACudaOnly(NNTestCase):
             fudge_factors['grad_key'] = 70.0
             if seq_len_k >= 2048:
                 fudge_factors['grad_key'] = 160.0
-                fudge_factors['grad_query'] = 670.0
+                fudge_factors['grad_query'] = 650.0
             if dtype == torch.float32:
                 fudge_factors['grad_key'] = 90.0
@@ -3415,7 +3415,7 @@ class TestSDPACudaOnly(NNTestCase):
             fudge_factors['grad_key'] = 70.0
             if seq_len_k >= 2048:
                 fudge_factors['grad_key'] = 160.0
-                fudge_factors['grad_query'] = 670.0  # gfx90a
+                fudge_factors['grad_query'] = 650.0
             if dtype == torch.float32:
                 fudge_factors['grad_key'] = 90.0

View File

@@ -57,9 +57,9 @@ def CDNA2OrLater():
 def evaluate_platform_supports_flash_attention():
     if TEST_WITH_ROCM:
-        arch_list = ["gfx90a", "gfx942", "gfx1100", "gfx1201", "gfx950"]
+        arch_list = ["gfx90a", "gfx942", "gfx1100"]
         if os.environ.get("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "0") != "0":
-            arch_list += ["gfx1101", "gfx1150", "gfx1151", "gfx1200"]
+            arch_list += ["gfx1201", "gfx950"]
         return evaluate_gfx_arch_within(arch_list)
     if TEST_CUDA:
         return not IS_WINDOWS and SM80OrLater
@@ -67,9 +67,9 @@ def evaluate_platform_supports_flash_attention():
 def evaluate_platform_supports_efficient_attention():
     if TEST_WITH_ROCM:
-        arch_list = ["gfx90a", "gfx942", "gfx1100", "gfx1201", "gfx950"]
+        arch_list = ["gfx90a", "gfx942", "gfx1100"]
         if os.environ.get("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "0") != "0":
-            arch_list += ["gfx1101", "gfx1150", "gfx1151", "gfx1200"]
+            arch_list += ["gfx1201", "gfx950"]
         return evaluate_gfx_arch_within(arch_list)
     if TEST_CUDA:
         return True
return True return True