From 1036f6d114bc22a9b4cf620cf7f8364ea2fd7a60 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 20 Jun 2025 15:35:25 +0000 Subject: [PATCH] Revert "[ROCm] Bump AOTriton to 0.10b (#156290)" This reverts commit 34d8e64ef64d88324092a2028884c54c13e086b3. Reverted https://github.com/pytorch/pytorch/pull/156290 on behalf of https://github.com/atalman due to failing multiple internal tests ([comment](https://github.com/pytorch/pytorch/pull/156290#issuecomment-2992072727)) --- .../native/transformers/cuda/attention.cu | 16 +- .../transformers/cuda/attention_backward.cu | 12 - .../hip/flash_attn/aot/mha_all_aot.hip | 384 +++++------------- .../transformers/hip/flash_attn/flash_api.h | 152 +++++-- cmake/External/aotriton.cmake | 33 +- test/test_transformers.py | 4 +- torch/testing/_internal/common_cuda.py | 8 +- 7 files changed, 241 insertions(+), 368 deletions(-) diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index 80049aa9a83..125b95de7a3 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -1113,10 +1113,8 @@ _flash_attention_forward( std::optional alibi_slopes = _alibi_slopes; const float softcap = 0.0; -#ifndef USE_ROCM // ROCM backend accepts std::optional for window_size_left/right directly. - const int non_null_window_left = window_size_left.value_or(-1); - const int non_null_window_right = window_size_right.value_or(-1); -#endif + const int non_null_window_left = window_size_left.has_value() ? window_size_left.value() : -1; + const int non_null_window_right = window_size_right.has_value() ? window_size_right.value() : -1; // We are going to have two paths: // 1. The standard MHA path for dense tensors @@ -1153,13 +1151,8 @@ _flash_attention_forward( softmax_scale, false /*zero_tensors*/, is_causal, -#ifdef USE_ROCM - window_size_left, - window_size_right, -#else non_null_window_left, non_null_window_right, -#endif softcap, return_debug_mask, std::nullopt /*gen_*/); @@ -1182,13 +1175,8 @@ _flash_attention_forward( dropout_p, softmax_scale, is_causal, -#ifdef USE_ROCM - window_size_left, - window_size_right, -#else non_null_window_left, non_null_window_right, -#endif softcap, return_debug_mask, /*return_softmax (this is used for testing)*/ std::nullopt); diff --git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu index 56715b4a038..5552807aa0e 100644 --- a/aten/src/ATen/native/transformers/cuda/attention_backward.cu +++ b/aten/src/ATen/native/transformers/cuda/attention_backward.cu @@ -87,10 +87,8 @@ std::tuple _flash_attention_backward( auto contiguous_grad_out = grad_out.contiguous(); auto contiguous_out = out.contiguous(); -#ifndef USE_ROCM // ROCM backend accepts std::optional for window_size_left/right directly. const int non_null_window_left = window_size_left.has_value() ? window_size_left.value() : -1; const int non_null_window_right = window_size_right.has_value() ? 
window_size_right.value() : -1; -#endif std::optional dq{std::nullopt}; std::optional dk{std::nullopt}; @@ -138,13 +136,8 @@ std::tuple _flash_attention_backward( softmax_scale, false /*zero_tensors*/, is_causal, -#ifdef USE_ROCM - window_size_left, - window_size_right, -#else non_null_window_left, non_null_window_right, -#endif softcap, determinisitic, philox_seed, @@ -166,13 +159,8 @@ std::tuple _flash_attention_backward( dropout_p, softmax_scale, is_causal, -#ifdef USE_ROCM - window_size_left, - window_size_right, -#else non_null_window_left, non_null_window_right, -#endif softcap, determinisitic, philox_seed, diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip index 7bed8bef1b9..92a3dabbd7d 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip @@ -64,14 +64,8 @@ #include #include -#if AOTRITON_VERSION_MINOR < 9 -#error "This adaptor code is only tested with AOTriton >= 0.9" -#endif - -#if (AOTRITON_VERSION_MAJOR * 100 + AOTRITON_VERSION_MINOR) >= 10 -#define V3_API 1 -#else -#define V3_API 0 +#if AOTRITON_VERSION_MINOR != 9 +#error "This adaptor code is only tested with AOTriton 0.9.x" #endif namespace pytorch_flash { @@ -87,38 +81,6 @@ void check_gpu_arch(hipStream_t stream) { } } -std::tuple -calculate_swa(std::optional window_size_left, - std::optional window_size_right, - int max_seqlen_q, - int max_seqlen_k, - bool is_causal) { -#if V3_API // SWA is exposed through V3 API - bool needs_swa = false; - using aotriton::v3::flash::WindowValue; - // Default values when std::optional window_size_left/right have no value - int window_left = max_seqlen_q; - int window_right = max_seqlen_k; - if (is_causal) { - window_left = WindowValue::TopLeftAligned; - window_right = WindowValue::TopLeftAligned; - } - if (window_size_left.has_value() || window_size_right.has_value()) { - needs_swa = true; - window_left = window_size_left.value_or(window_left); - window_right = window_size_right.value_or(window_right); - } - return std::make_tuple(needs_swa, window_left, window_right); -#else - if (window_size_left.has_value() || window_size_right.has_value()) { - TORCH_WARN_ONCE("Current AOTriton does not support sliding window attention (SWA)." - " Both window_size_left and window_size_right will be ignored." 
- " Re-compile PyTorch with AOTriton >= 0.10b to enable SWA support."); - } - return std::make_tuple(false, 0, 0); -#endif -} - // We want to checkpoint and save the RNG state for backward if dropout // We get the default generator and return the seed and offset which will // be used in the backward function @@ -165,8 +127,8 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x const float p_dropout, const float softmax_scale, bool is_causal, - std::optional window_size_left, - std::optional window_size_right, + int window_size_left, + int window_size_right, const bool return_softmax, const std::optional& gen_) { auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); @@ -199,6 +161,7 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); if (seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case + if (is_causal) { window_size_right = 0; } CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og); CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og); @@ -249,19 +212,6 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x atomic_counter = at::zeros({1}, opts.dtype(at::kInt)); } - auto [needs_swa, window_left, window_right] = calculate_swa(window_size_left, - window_size_right, - seqlen_q, - seqlen_k, - is_causal); -#if V3_API - const bool uses_swa = needs_swa; -#else - // When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be - // optimized out (hopefully). - constexpr bool uses_swa = false; -#endif - hipError_t err; // TODO: Error handling using aotriton::v2::flash::attn_fwd; using sdp::aotriton_adapter::mk_aotensor; @@ -276,54 +226,23 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x auto seed_output = mk_philoxtensor(use_philox_state ? seed_t.data_ptr() : nullptr); auto offset_output = mk_philoxtensor(use_philox_state ? offset_t.data_ptr() : nullptr); auto persistent_counter = mk_atomictensor(is_causal ? 
atomic_counter.data_ptr() : nullptr); - if (uses_swa) { -#if V3_API - using aotriton::v3::flash::CausalType; - using aotriton::v3::flash::VarlenType; - aotriton::v3::flash::attn_fwd_params params; - params.Q = mk_aotensor(q_t, "q"); - params.K = mk_aotensor(k_t, "k"); - params.V = mk_aotensor(v_t, "v"); - params.Sm_scale = softmax_scale; - params.L = mk_aotensor<2>(M, "M"); - params.Out = mk_aotensor(output_t, "Out"); - params.Max_seqlen_q = seqlen_q; // Unused if cu_seqlens_q is empty - params.Max_seqlen_k = seqlen_k; // Unused if cu_seqlens_k is empty - params.dropout_p = p_dropout; - params.philox_seed_ptr = seed; - params.philox_offset1 = offset1; - params.philox_offset2 = offset2; - params.philox_seed_output = seed_output; - params.philox_offset_output = offset_output; - params.encoded_softmax = mk_aotensor(softmax_fa_t, "encoded_softmax"); - params.persistent_atomic_counter = persistent_counter; - params.causal_type = CausalType::WindowedAttention; - params.varlen_type = VarlenType::None; - params.window_left = window_left; - params.window_right = window_right; - err = aotriton::v3::flash::attn_fwd(params, - aotriton::v3::flash::attn_fwd_params::kVersion, - stream); -#endif - } else { - err = attn_fwd(mk_aotensor(q_t, "q"), - mk_aotensor(k_t, "k"), - mk_aotensor(v_t, "v"), - empty_bias, - softmax_scale, - mk_aotensor<2>(M, "M"), - mk_aotensor(output_t, "Out"), - p_dropout, - seed, - offset1, - offset2, - seed_output, - offset_output, - mk_aotensor(softmax_fa_t, "encoded_softmax"), - is_causal, - persistent_counter, - stream); - } + err = attn_fwd(mk_aotensor(q_t, "q"), + mk_aotensor(k_t, "k"), + mk_aotensor(v_t, "v"), + empty_bias, + softmax_scale, + mk_aotensor<2>(M, "M"), + mk_aotensor(output_t, "Out"), + p_dropout, + seed, + offset1, + offset2, + seed_output, + offset_output, + mk_aotensor(softmax_fa_t, "encoded_softmax"), + is_causal, + persistent_counter, + stream); return {out, q_padded, k_padded, v_padded, M.view({batch_size, num_heads, seqlen_q}), seed_t, offset_t, softmax_fa_t}; } @@ -344,8 +263,8 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot const float softmax_scale, const bool zero_tensors, bool is_causal, - std::optional window_size_left, - std::optional window_size_right, + int window_size_left, + int window_size_right, const bool return_softmax, const std::optional& gen_) { TORCH_CHECK(!seqused_k.has_value(), "[ROCm] mha_varlen_fwd: seqused_k must be nullopt"); @@ -393,6 +312,13 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot TORCH_CHECK(head_size_og <= 512, "FlashAttention on ROCm forward only supports head dimension at most 512"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + if (window_size_left >= max_seqlen_k) { + window_size_left = -1; + } + if (window_size_right >= max_seqlen_k) { + window_size_right = -1; + } + CHECK_SHAPE(temp_q, total_q, num_heads, head_size_og); const int total_k = k.size(0); CHECK_SHAPE(k, total_k, num_heads_k, head_size_og); @@ -442,19 +368,6 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot } } - auto [needs_swa, window_left, window_right] = calculate_swa(window_size_left, - window_size_right, - max_seqlen_q, - max_seqlen_k, - is_causal); -#if V3_API - const bool uses_swa = needs_swa; -#else - // When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be - // optimized out (hopefully). 
- constexpr bool uses_swa = false; -#endif - auto [seed_t, offset_t, philox_state, use_philox_state] = prepare_philox_arguments(p_dropout, batch_size * num_heads * 32); @@ -477,58 +390,27 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr()) : nullscalar; auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr()) : nullscalar; auto persistent_counter = is_causal ? mk_philoxtensor(atomic_counter.data_ptr()) : nullscalar; - if (uses_swa) { - using aotriton::v3::flash::CausalType; - using aotriton::v3::flash::VarlenType; - aotriton::v3::flash::attn_fwd_params params; - params.Q = mk_aotensor(q_padded, "q"); - params.K = mk_aotensor(k_padded, "k"); - params.V = mk_aotensor(v_padded, "v"); - params.Sm_scale = softmax_scale; - params.L = mk_aotensor<2>(M, "M"); - params.Out = mk_aotensor(out_padded, "Out"); - params.cu_seqlens_q = mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"); - params.cu_seqlens_k = mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"); - params.Max_seqlen_q = max_seqlen_q; // Unused if cu_seqlens_q is empty - params.Max_seqlen_k = max_seqlen_k; // Unused if cu_seqlens_k is empty - params.dropout_p = p_dropout; - params.philox_seed_ptr = seed; - params.philox_offset1 = offset1; - params.philox_offset2 = offset2; - params.philox_seed_output = seed_output; - params.philox_offset_output = offset_output; - params.encoded_softmax = mk_aotensor(softmax_fa_t, "encoded_softmax"); - params.persistent_atomic_counter = persistent_counter; - params.causal_type = CausalType::WindowedAttention; - params.varlen_type = VarlenType::CompactVarlen; - params.window_left = window_left; - params.window_right = window_right; - err = aotriton::v3::flash::attn_fwd(params, - aotriton::v3::flash::attn_fwd_params::kVersion, - stream); - } else { - err = attn_fwd_compact_varlen(mk_aotensor(q_padded, "q"), - mk_aotensor(k_padded, "k"), - mk_aotensor(v_padded, "v"), - empty_bias, - mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"), - mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"), - max_seqlen_q, - max_seqlen_k, - softmax_scale, - mk_aotensor<2>(M, "M"), - mk_aotensor(out_padded, "Out"), - p_dropout, - seed, - offset1, - offset2, - seed_output, - offset_output, - mk_aotensor(softmax_fa_t, "encoded_softmax"), - is_causal, - persistent_counter, - stream); - } + err = attn_fwd_compact_varlen(mk_aotensor(q_padded, "q"), + mk_aotensor(k_padded, "k"), + mk_aotensor(v_padded, "v"), + empty_bias, + mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"), + mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"), + max_seqlen_q, + max_seqlen_k, + softmax_scale, + mk_aotensor<2>(M, "M"), + mk_aotensor(out_padded, "Out"), + p_dropout, + seed, + offset1, + offset2, + seed_output, + offset_output, + mk_aotensor(softmax_fa_t, "encoded_softmax"), + is_causal, + persistent_counter, + stream); } else { // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. 
out.zero_(); @@ -552,8 +434,8 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea const float p_dropout, // probability to drop const float softmax_scale, const bool is_causal, - std::optional window_size_left, - std::optional window_size_right, + int window_size_left, + int window_size_right, const bool deterministic, const at::Tensor& philox_seed, const at::Tensor& philox_offset) { @@ -642,19 +524,6 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea dv = at::empty_like(k); } - auto [needs_swa, window_left, window_right] = calculate_swa(window_size_left, - window_size_right, - seqlen_q, - seqlen_k, - is_causal); -#if V3_API - const bool uses_swa = needs_swa; -#else - // When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be - // optimized out (hopefully). - constexpr bool uses_swa = false; -#endif - auto opts = q.options(); auto softmax_d = at::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); @@ -672,40 +541,10 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea int d_head = head_size_og; bool use_fused_bwd = d_head <= 192 && d_head * seqlen_q < 64 * 512; hipError_t err; // TODO: Error handling - using sdp::aotriton_adapter::mk_aotensor; - using sdp::aotriton_adapter::mk_aoscalartensor; - if (uses_swa) { - // Fused BWD does not support SWA - at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous(); - using aotriton::v3::flash::CausalType; - using aotriton::v3::flash::VarlenType; - aotriton::v3::flash::attn_bwd_params params; - params.Q = mk_aotensor(q_t, "q"); - params.K = mk_aotensor(k_t, "k"); - params.V = mk_aotensor(v_t, "v"); - params.Sm_scale = softmax_scale; - params.Out = mk_aotensor(out_t, "out"); - params.DO = mk_aotensor(dout_t, "dout"); - params.DK = mk_aotensor(dq_t, "dq"); - params.DV = mk_aotensor(dk_t, "dk"); - params.DQ = mk_aotensor(dv_t, "dv"); - params.L = mk_aotensor<2>(softmax_lse_cont, "L"); - params.D = mk_aotensor<2>(delta, "delta"); - params.Max_seqlen_q = seqlen_q; // Unused if cu_seqlens_q is empty - params.Max_seqlen_k = seqlen_k; // Unused if cu_seqlens_k is empty - params.dropout_p = p_dropout; - params.philox_seed_ptr = mk_aoscalartensor(philox_seed); - params.philox_offset1 = mk_aoscalartensor(philox_offset); - params.philox_offset2 = 0; - params.causal_type = CausalType::WindowedAttention; - params.varlen_type = VarlenType::None; - params.window_left = window_left; - params.window_right = window_right; - err = aotriton::v3::flash::attn_bwd(params, - aotriton::v3::flash::attn_bwd_params::kVersion, - stream); - } else if (use_fused_bwd) { + if (use_fused_bwd) { using aotriton::v2::flash::attn_bwd_fused; + using sdp::aotriton_adapter::mk_aotensor; + using sdp::aotriton_adapter::mk_aoscalartensor; using sdp::aotriton_adapter::cast_dtype; aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); err = attn_bwd_fused(mk_aotensor(q_t, "q"), @@ -729,6 +568,8 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea } else { at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous(); using aotriton::v2::flash::attn_bwd; + using sdp::aotriton_adapter::mk_aotensor; + using sdp::aotriton_adapter::mk_aoscalartensor; using sdp::aotriton_adapter::cast_dtype; aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); err = attn_bwd(mk_aotensor(q_t, "q"), @@ -774,14 +615,17 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size 
const float softmax_scale, const bool zero_tensors, const bool is_causal, - std::optional window_size_left, - std::optional window_size_right, + int window_size_left, + int window_size_right, const bool deterministic, const at::Tensor& philox_seed, const at::Tensor& philox_offset) { TORCH_CHECK(!alibi_slopes_.has_value(), "[ROCm] mha_varlen_fwd: alibi_slopes_ must be nullopt"); + if (is_causal) { + window_size_right = 0; + } // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; @@ -825,6 +669,9 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size TORCH_CHECK(head_size <= 512, "FlashAttention on ROCm backward only supports head dimension at most 512"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + if (window_size_left >= max_seqlen_k) { window_size_left = -1; } + if (window_size_right >= max_seqlen_k) { window_size_right = -1; } + CHECK_SHAPE(q, total_q, num_heads, head_size); CHECK_SHAPE(k, total_k, num_heads_k, head_size); CHECK_SHAPE(v, total_k, num_heads_k, head_size); @@ -887,19 +734,6 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size softmax_d.zero_(); } - auto [needs_swa, window_left, window_right] = calculate_swa(window_size_left, - window_size_right, - max_seqlen_q, - max_seqlen_k, - is_causal); -#if V3_API - const bool uses_swa = needs_swa; -#else - // When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be - // optimized out (hopefully). - constexpr bool uses_swa = false; -#endif - at::PhiloxCudaState philox_args; if (is_dropout) { if (at::cuda::currentStreamCaptureStatus() == @@ -913,66 +747,34 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size } if (max_seqlen_q > 0) { hipError_t err; // TODO: Error handling + using aotriton::v2::flash::attn_bwd_compact_varlen; using sdp::aotriton_adapter::mk_aotensor; using sdp::aotriton_adapter::mk_aoscalartensor; - if (uses_swa) { - using aotriton::v3::flash::CausalType; - using aotriton::v3::flash::VarlenType; - aotriton::v3::flash::attn_bwd_params params; - params.Q = mk_aotensor(q_padded, "q"); - params.K = mk_aotensor(k_padded, "k"); - params.V = mk_aotensor(v_padded, "v"); - params.Sm_scale = softmax_scale; - params.Out = mk_aotensor(out_t, "out"); - params.DO = mk_aotensor(dout_t, "dout"); - params.DK = mk_aotensor(dq_padded, "dq"); - params.DV = mk_aotensor(dk_padded, "dk"); - params.DQ = mk_aotensor(dv_padded, "dv"); - params.L = mk_aotensor<2>(softmax_lse_cont, "L"); - params.D = mk_aotensor<2>(delta, "delta"); - params.cu_seqlens_q = mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"); - params.cu_seqlens_k = mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"); - params.Max_seqlen_q = max_seqlen_q; // Unused if cu_seqlens_q is empty - params.Max_seqlen_k = max_seqlen_k; // Unused if cu_seqlens_k is empty - params.dropout_p = p_dropout; - params.philox_seed_ptr = mk_aoscalartensor(philox_seed); - params.philox_offset1 = mk_aoscalartensor(philox_offset); - params.philox_offset2 = 0; - params.causal_type = CausalType::WindowedAttention; - params.varlen_type = VarlenType::CompactVarlen; - params.window_left = window_left; - params.window_right = window_right; - err = aotriton::v3::flash::attn_bwd(params, - aotriton::v3::flash::attn_bwd_params::kVersion, - stream); - } else { - using aotriton::v2::flash::attn_bwd_compact_varlen; - using 
sdp::aotriton_adapter::cast_dtype; - aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); - err = attn_bwd_compact_varlen(mk_aotensor(q_padded, "q"), - mk_aotensor(k_padded, "k"), - mk_aotensor(v_padded, "v"), - mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"), - mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"), - max_seqlen_q, - max_seqlen_k, - empty_bias, - softmax_scale, - mk_aotensor(out_t, "out"), - mk_aotensor(dout_t, "dout"), - mk_aotensor(dq_padded, "dq"), - mk_aotensor(dk_padded, "dk"), - mk_aotensor(dv_padded, "dv"), - empty_bias, - mk_aotensor<2>(softmax_lse_cont, "L"), - mk_aotensor<2>(delta, "delta"), - p_dropout, - mk_aoscalartensor(philox_seed), - mk_aoscalartensor(philox_offset), - 0, - is_causal, - stream); - } + using sdp::aotriton_adapter::cast_dtype; + aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); + err = attn_bwd_compact_varlen(mk_aotensor(q_padded, "q"), + mk_aotensor(k_padded, "k"), + mk_aotensor(v_padded, "v"), + mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"), + mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"), + max_seqlen_q, + max_seqlen_k, + empty_bias, + softmax_scale, + mk_aotensor(out_t, "out"), + mk_aotensor(dout_t, "dout"), + mk_aotensor(dq_padded, "dq"), + mk_aotensor(dk_padded, "dk"), + mk_aotensor(dv_padded, "dv"), + empty_bias, + mk_aotensor<2>(softmax_lse_cont, "L"), + mk_aotensor<2>(delta, "delta"), + p_dropout, + mk_aoscalartensor(philox_seed), + mk_aoscalartensor(philox_offset), + 0, + is_causal, + stream); } else { // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. dq.zero_(); diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h index 17298aae948..ead742a1efd 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h +++ b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h @@ -51,8 +51,8 @@ mha_fwd_aot( const float p_dropout, const float softmax_scale, bool is_causal, - std::optional window_size_left, - std::optional window_size_right, + int window_size_left, + int window_size_right, const bool return_softmax, const std::optional& gen_); @@ -87,8 +87,8 @@ mha_varlen_fwd_aot( const float softmax_scale, const bool zero_tensors, bool is_causal, - std::optional window_size_left, - std::optional window_size_right, + int window_size_left, + int window_size_right, const bool return_softmax, const std::optional& gen_); @@ -110,8 +110,8 @@ std::tuple mha_bwd_aot( const float p_dropout, // probability to drop const float softmax_scale, const bool is_causal, - std::optional window_size_left, - std::optional window_size_right, + int window_size_left, + int window_size_right, const bool deterministic, const at::Tensor& philox_seed, const at::Tensor& philox_offset); @@ -141,8 +141,8 @@ std::tuple mha_varlen_bwd_aot( const float softmax_scale, const bool zero_tensors, const bool is_causal, - std::optional window_size_left, - std::optional window_size_right, + int window_size_left, + int window_size_right, const bool deterministic, const at::Tensor& philox_seed, const at::Tensor& philox_offset); @@ -290,16 +290,14 @@ mha_fwd( const float p_dropout, const float softmax_scale, bool is_causal, - std::optional window_size_left, - std::optional window_size_right, + int window_size_left, + int window_size_right, const float softcap, const bool return_softmax, std::optional gen_) { #if defined(USE_CK_FLASH_ATTENTION) if (at::globalContext().getROCmFAPreferredBackend() == 
at::ROCmFABackend::Ck) { - const int non_null_window_left = window_size_left.value_or(-1); - const int non_null_window_right = window_size_right.value_or(-1); std::optional dummy_attn_bias = std::nullopt; return mha_fwd_ck( q, @@ -309,13 +307,27 @@ mha_fwd( p_dropout, softmax_scale, is_causal, - non_null_window_left, - non_null_window_right, + window_size_left, + window_size_right, return_softmax, gen_, dummy_attn_bias); // Not used in flash attention + } else { + return mha_fwd_aot( + q, + k, + v, + out_, + alibi_slopes_, + p_dropout, + softmax_scale, + is_causal, + window_size_left, + window_size_right, + return_softmax, + gen_); } -#endif +#else return mha_fwd_aot( q, k, @@ -329,6 +341,7 @@ mha_fwd( window_size_right, return_softmax, gen_); +#endif } inline std::tuple< @@ -363,8 +376,8 @@ mha_varlen_fwd( const float softmax_scale, const bool zero_tensors, bool is_causal, - std::optional window_size_left, - std::optional window_size_right, + int window_size_left, + int window_size_right, const float softcap, const bool return_softmax, std::optional gen_) { @@ -372,8 +385,6 @@ mha_varlen_fwd( if (at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) { std::optional dummy_attn_bias = std::nullopt; - const int non_null_window_left = window_size_left.value_or(-1); - const int non_null_window_right = window_size_right.value_or(-1); return mha_varlen_fwd_ck( q, k, @@ -388,13 +399,34 @@ mha_varlen_fwd( softmax_scale, zero_tensors, is_causal, - non_null_window_left, - non_null_window_right, + window_size_left, + window_size_right, return_softmax, gen_, dummy_attn_bias); // Not used in flash attention + } else { + return mha_varlen_fwd_aot( + q, + k, + v, + out_, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + block_table_, + alibi_slopes_, + max_seqlen_q, + max_seqlen_k, + p_dropout, + softmax_scale, + zero_tensors, + is_causal, + window_size_left, + window_size_right, + return_softmax, + gen_); } -#endif +#else return mha_varlen_fwd_aot( q, k, @@ -415,6 +447,7 @@ mha_varlen_fwd( window_size_right, return_softmax, gen_); +#endif } inline std::tuple mha_bwd( @@ -435,18 +468,16 @@ inline std::tuple mha_bwd( const float p_dropout, // probability to drop const float softmax_scale, const bool is_causal, - std::optional window_size_left, - std::optional window_size_right, + int window_size_left, + int window_size_right, const float softcap, const bool deterministic, const at::Tensor philox_seed, const at::Tensor philox_offset) { +#if defined(USE_CK_FLASH_ATTENTION) if (at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) { -#if defined(USE_CK_FLASH_ATTENTION) std::optional non_null_dbias = std::nullopt; - const int non_null_window_left = window_size_left.value_or(-1); - const int non_null_window_right = window_size_right.value_or(-1); auto[dQuery, dKey, dValue, @@ -467,16 +498,38 @@ inline std::tuple mha_bwd( p_dropout, softmax_scale, is_causal, - non_null_window_left, - non_null_window_right, + window_size_left, + window_size_right, deterministic, philox_seed, philox_offset); // for FA return [dQ, dV, dK, dSoftmax] return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue), std::move(dSoftmax)); + } else { + return mha_bwd_aot( + dout, + q, + k, + v, + out, + softmax_lse, + dq_, + dk_, + dv_, + alibi_slopes_, + p_dropout, + softmax_scale, + is_causal, + window_size_left, + window_size_right, + deterministic, + philox_seed, + philox_offset); + } #else + if(at::globalContext().getROCmFAPreferredBackend() == + at::ROCmFABackend::Ck) { 
TORCH_WARN_ONCE("Warning! You have opted to use CK flash attention backend in a build that was not compiled using USE_CK_FLASH_ATTENTION=1. Please set this variable and try again. Defaulting to use aotriton backend..."); -#endif } return mha_bwd_aot( dout, @@ -497,6 +550,7 @@ inline std::tuple mha_bwd( deterministic, philox_seed, philox_offset); +#endif } inline std::tuple mha_varlen_bwd( @@ -524,8 +578,8 @@ inline std::tuple mha_varlen_bwd const float softmax_scale, const bool zero_tensors, const bool is_causal, - std::optional window_size_left, - std::optional window_size_right, + int window_size_left, + int window_size_right, const float softcap, const bool deterministic, const at::Tensor philox_seed, @@ -534,8 +588,6 @@ inline std::tuple mha_varlen_bwd if (at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) { std::optional non_null_dbias = std::nullopt; - const int non_null_window_left = window_size_left.value_or(-1); - const int non_null_window_right = window_size_right.value_or(-1); auto[dQuery, dKey, dValue, @@ -561,15 +613,40 @@ inline std::tuple mha_varlen_bwd softmax_scale, zero_tensors, is_causal, - non_null_window_left, - non_null_window_right, + window_size_left, + window_size_right, deterministic, philox_seed, philox_offset); // for FA return [dQ, dV, dK, dSoftmax] return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue), std::move(dSoftmax)); + } else { + return mha_varlen_bwd_aot( + dout, + q, + k, + v, + out, + softmax_lse, + dq_, + dk_, + dv_, + cu_seqlens_q, + cu_seqlens_k, + alibi_slopes_, + max_seqlen_q, + max_seqlen_k, + p_dropout, + softmax_scale, + zero_tensors, + is_causal, + window_size_left, + window_size_right, + deterministic, + philox_seed, + philox_offset); } -#endif +#else return mha_varlen_bwd_aot( dout, q, @@ -594,6 +671,7 @@ inline std::tuple mha_varlen_bwd deterministic, philox_seed, philox_offset); +#endif } } // namespace pytorch_flash diff --git a/cmake/External/aotriton.cmake b/cmake/External/aotriton.cmake index 8004b0f400a..9c1862f6b44 100644 --- a/cmake/External/aotriton.cmake +++ b/cmake/External/aotriton.cmake @@ -1,3 +1,16 @@ +macro(get_target_gpus_from_pytorch target_gpus) + set(gfx90a_key MI200) + set(gfx942_key MI300X) + set(gfx1100_key Navi31) + + foreach(X IN LISTS PYTORCH_ROCM_ARCH) + set(key ${X}) + string(APPEND key "_key") + string(APPEND target_gpus ${${key}}) + string(APPEND target_gpus "|") + endforeach() +endmacro() + if(NOT __AOTRITON_INCLUDED) set(__AOTRITON_INCLUDED TRUE) @@ -9,22 +22,22 @@ if(NOT __AOTRITON_INCLUDED) # Replaces .ci/docker/aotriton_version.txt # Note packages information may have versions skipped (due to no ABI breaks) # But they must be listed from lower version to higher version - set(__AOTRITON_VER "0.10b") + set(__AOTRITON_VER "0.9.2b") set(__AOTRITON_MANYLINUX_LIST + "manylinux_2_28" # rocm6.2 "manylinux_2_28" # rocm6.3 "manylinux_2_28" # rocm6.4 - "manylinux_2_28" # rocm7.0 ) set(__AOTRITON_ROCM_LIST + "rocm6.2" "rocm6.3" "rocm6.4" - "rocm7.0" ) - set(__AOTRITON_CI_COMMIT "6fca155f4deeb8d9529326f7b69f350aeeb93477") + set(__AOTRITON_CI_COMMIT "b388d223d8c7213545603e00f6f3148c54d1f525") set(__AOTRITON_SHA256_LIST - "861cd9f7479eec943933c27cb86920247e5b5dd139bc7c1376c81808abb7d7fe" # rocm6.3 - "acea7d811a2d3bbe718b6e07fc2a9f739e49eecd60b4b6a36fcb3fe8edf85d78" # rocm6.4 - "7e29c325d5bd33ba896ddb106f5d4fc7d715274dca7fe937f724fffa82017838" # rocm7.0 + "08d84f96f4c984179f80f517c0431c7511ee26bb0ce9bd05a827573ddd78cc79" # rocm6.2 + 
"9094d59717e7e6eace9126ca100dd0e86510f07fc6c3a349569fc4e2d9056604" # rocm6.3 + "41190202c2736d5ff75b13a3abc0fb52ebfbb67226cf85dc3de7699c7000db44" # rocm6.4 ) set(__AOTRITON_Z "gz") @@ -37,13 +50,17 @@ if(NOT __AOTRITON_INCLUDED) set(__AOTRITON_INSTALL_DIR "$ENV{AOTRITON_INSTALLED_PREFIX}") message(STATUS "Using Preinstalled AOTriton at ${__AOTRITON_INSTALL_DIR}") elseif(DEFINED ENV{AOTRITON_INSTALL_FROM_SOURCE}) + set(target_gpus "") + get_target_gpus_from_pytorch(target_gpus) ExternalProject_Add(aotriton_external GIT_REPOSITORY https://github.com/ROCm/aotriton.git GIT_TAG ${__AOTRITON_CI_COMMIT} PREFIX ${__AOTRITON_EXTERN_PREFIX} INSTALL_DIR ${__AOTRITON_INSTALL_DIR} + LIST_SEPARATOR | CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__AOTRITON_INSTALL_DIR} - -DAOTRITON_TARGET_ARCH:STRING=${PYTORCH_ROCM_ARCH} + -DTARGET_GPUS:STRING=${target_gpus} + -DAOTRITON_COMPRESS_KERNEL=ON -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DAOTRITON_NO_PYTHON=ON -DAOTRITON_NO_SHARED=OFF diff --git a/test/test_transformers.py b/test/test_transformers.py index 269ffe682ab..46f57bdec52 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -3294,7 +3294,7 @@ class TestSDPACudaOnly(NNTestCase): fudge_factors['grad_key'] = 70.0 if seq_len_k >= 2048: fudge_factors['grad_key'] = 160.0 - fudge_factors['grad_query'] = 670.0 + fudge_factors['grad_query'] = 650.0 if dtype == torch.float32: fudge_factors['grad_key'] = 90.0 @@ -3415,7 +3415,7 @@ class TestSDPACudaOnly(NNTestCase): fudge_factors['grad_key'] = 70.0 if seq_len_k >= 2048: fudge_factors['grad_key'] = 160.0 - fudge_factors['grad_query'] = 670.0 # gfx90a + fudge_factors['grad_query'] = 650.0 if dtype == torch.float32: fudge_factors['grad_key'] = 90.0 diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py index a211851d671..82f486e97f6 100644 --- a/torch/testing/_internal/common_cuda.py +++ b/torch/testing/_internal/common_cuda.py @@ -57,9 +57,9 @@ def CDNA2OrLater(): def evaluate_platform_supports_flash_attention(): if TEST_WITH_ROCM: - arch_list = ["gfx90a", "gfx942", "gfx1100", "gfx1201", "gfx950"] + arch_list = ["gfx90a", "gfx942", "gfx1100"] if os.environ.get("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "0") != "0": - arch_list += ["gfx1101", "gfx1150", "gfx1151", "gfx1200"] + arch_list += ["gfx1201", "gfx950"] return evaluate_gfx_arch_within(arch_list) if TEST_CUDA: return not IS_WINDOWS and SM80OrLater @@ -67,9 +67,9 @@ def evaluate_platform_supports_flash_attention(): def evaluate_platform_supports_efficient_attention(): if TEST_WITH_ROCM: - arch_list = ["gfx90a", "gfx942", "gfx1100", "gfx1201", "gfx950"] + arch_list = ["gfx90a", "gfx942", "gfx1100"] if os.environ.get("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "0") != "0": - arch_list += ["gfx1101", "gfx1150", "gfx1151", "gfx1200"] + arch_list += ["gfx1201", "gfx950"] return evaluate_gfx_arch_within(arch_list) if TEST_CUDA: return True