From 3ecfe6be256c585bcadf4c845d7119545444a222 Mon Sep 17 00:00:00 2001
From: drisspg
Date: Wed, 26 Feb 2025 00:10:59 +0000
Subject: [PATCH] [Submodule] Turning flash-attention integration into 3rd party submod (#144120) (#146372)

Summary:
# Summary

### Sticky points
CUDA-graph RNG handling has changed / deviated from the original implementation. Due to BC we will be left with a dangling 'offset' value and some confusing naming.

## Dependencies
- Flash PR: https://github.com/Dao-AILab/flash-attention/pull/1419

### Other Points
- The BC linter is complaining about losing generate.py and its functions, which is not a real BC surface.

cc albanD

imported-using-ghimport

Test Plan: Imported from OSS

Building in dev:
`buck build @//mode/dev-nosan -c fbcode.nvcc_arch=h100a //caffe2:ATen-cu --show-full-output`

Running nm on the resulting .so, I do see that the flash symbols are correctly named:
```
0000000001c3dfb0 t pytorch_flash::run_mha_bwd(pytorch_flash::Flash_bwd_params&, CUstream_st*)::$_0::operator()() const::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const::{lambda()#7}::operator()() const
0000000001c36080 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#2}::operator()() const::{lambda()#1}::operator()() const::{lambda()#6}::operator()() const
0000000001c360e0 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#2}::operator()() const::{lambda()#1}::operator()() const::{lambda()#7}::operator()() const
0000000001c35fc0 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const::{lambda()#6}::operator()() const
0000000001c36020 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const::{lambda()#7}::operator()() const
```

Reviewed By: vkuzo

Differential Revision: D68502879

Pulled By: drisspg

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146372
Approved by: https://github.com/jbschlosser
---
 CMakeLists.txt                                   |    5 -
 aten/src/ATen/CMakeLists.txt                     |    9 +-
 aten/src/ATen/native/native_functions.yaml       |    6 +-
 .../native/transformers/cuda/attention.cu        |   10 +-
 .../transformers/cuda/attention_backward.cu      |    7 +-
 .../transformers/cuda/flash_attn/alibi.h         |   74 -
 .../transformers/cuda/flash_attn/block_info.h    |   46 -
 .../transformers/cuda/flash_attn/dropout.h       |   96 --
 .../transformers/cuda/flash_attn/flash.h         |  190 ---
 .../cuda/flash_attn/flash_api.cpp                |  192 +--
 .../transformers/cuda/flash_attn/flash_api.h     |    9 +-
 .../cuda/flash_attn/flash_bwd_kernel.h           |  827 -----------
 .../flash_attn/flash_bwd_launch_template.h       |  338 -----
 .../flash_attn/flash_bwd_preprocess_kernel.h     |  377 -----
 .../cuda/flash_attn/flash_fwd_kernel.h           | 1254 -----------------
 .../flash_attn/flash_fwd_launch_template.h       |  378 -----
 .../cuda/flash_attn/kernel_traits.h              |  347 -----
 .../kernels/flash_bwd_hdim128_bf16_sm80.cu       |   14 -
 .../kernels/flash_bwd_hdim128_fp16_sm80.cu       |   14 -
 .../kernels/flash_bwd_hdim160_bf16_sm80.cu       |   14 -
 .../kernels/flash_bwd_hdim160_fp16_sm80.cu       |   14 -
 .../kernels/flash_bwd_hdim192_bf16_sm80.cu       |   14 -
 .../kernels/flash_bwd_hdim192_fp16_sm80.cu       |   14 -
 .../kernels/flash_bwd_hdim224_bf16_sm80.cu       |   14 -
 .../kernels/flash_bwd_hdim224_fp16_sm80.cu       |   14 -
 .../kernels/flash_bwd_hdim256_bf16_sm80.cu       |   14 -
 .../kernels/flash_bwd_hdim256_fp16_sm80.cu       |   14 -
.../kernels/flash_bwd_hdim32_bf16_sm80.cu | 14 - .../kernels/flash_bwd_hdim32_fp16_sm80.cu | 14 - .../kernels/flash_bwd_hdim64_bf16_sm80.cu | 14 - .../kernels/flash_bwd_hdim64_fp16_sm80.cu | 14 - .../kernels/flash_bwd_hdim96_bf16_sm80.cu | 14 - .../kernels/flash_bwd_hdim96_fp16_sm80.cu | 14 - .../kernels/flash_fwd_hdim128_bf16_sm80.cu | 14 - .../kernels/flash_fwd_hdim128_fp16_sm80.cu | 14 - .../kernels/flash_fwd_hdim160_bf16_sm80.cu | 14 - .../kernels/flash_fwd_hdim160_fp16_sm80.cu | 14 - .../kernels/flash_fwd_hdim192_bf16_sm80.cu | 14 - .../kernels/flash_fwd_hdim192_fp16_sm80.cu | 14 - .../kernels/flash_fwd_hdim224_bf16_sm80.cu | 14 - .../kernels/flash_fwd_hdim224_fp16_sm80.cu | 14 - .../kernels/flash_fwd_hdim256_bf16_sm80.cu | 14 - .../kernels/flash_fwd_hdim256_fp16_sm80.cu | 14 - .../kernels/flash_fwd_hdim32_bf16_sm80.cu | 14 - .../kernels/flash_fwd_hdim32_fp16_sm80.cu | 14 - .../kernels/flash_fwd_hdim64_bf16_sm80.cu | 14 - .../kernels/flash_fwd_hdim64_fp16_sm80.cu | 14 - .../kernels/flash_fwd_hdim96_bf16_sm80.cu | 14 - .../kernels/flash_fwd_hdim96_fp16_sm80.cu | 14 - .../flash_fwd_split_hdim128_bf16_sm80.cu | 12 - .../flash_fwd_split_hdim128_fp16_sm80.cu | 12 - .../flash_fwd_split_hdim160_bf16_sm80.cu | 12 - .../flash_fwd_split_hdim160_fp16_sm80.cu | 12 - .../flash_fwd_split_hdim192_bf16_sm80.cu | 12 - .../flash_fwd_split_hdim192_fp16_sm80.cu | 12 - .../flash_fwd_split_hdim224_bf16_sm80.cu | 12 - .../flash_fwd_split_hdim224_fp16_sm80.cu | 12 - .../flash_fwd_split_hdim256_bf16_sm80.cu | 12 - .../flash_fwd_split_hdim256_fp16_sm80.cu | 12 - .../flash_fwd_split_hdim32_bf16_sm80.cu | 12 - .../flash_fwd_split_hdim32_fp16_sm80.cu | 12 - .../flash_fwd_split_hdim64_bf16_sm80.cu | 12 - .../flash_fwd_split_hdim64_fp16_sm80.cu | 12 - .../flash_fwd_split_hdim96_bf16_sm80.cu | 12 - .../flash_fwd_split_hdim96_fp16_sm80.cu | 12 - .../flash_attn/kernels/generate_kernels.py | 109 -- .../transformers/cuda/flash_attn/mask.h | 213 --- .../transformers/cuda/flash_attn/rotary.h | 152 -- .../transformers/cuda/flash_attn/softmax.h | 186 --- .../transformers/cuda/flash_attn/utils.h | 394 ------ .../transformers/hip/flash_attn/flash_api.h | 4 + caffe2/CMakeLists.txt | 19 +- tools/autograd/derivatives.yaml | 8 +- torch/_meta_registrations.py | 43 +- .../aoti_torch/generated/c_shim_cuda.h | 2 +- 75 files changed, 186 insertions(+), 5749 deletions(-) delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/alibi.h delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/block_info.h delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/dropout.h delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/flash.h delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_kernel.h delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_preprocess_kernel.h delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_kernel.h delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernel_traits.h delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim128_bf16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim128_fp16_sm80.cu delete mode 100644 
aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim160_bf16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim160_fp16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim192_bf16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim192_fp16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim224_bf16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim224_fp16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim256_bf16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim256_fp16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim32_bf16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim32_fp16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim64_bf16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim64_fp16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim96_bf16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim96_fp16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim128_bf16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim128_fp16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim160_bf16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim160_fp16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim192_bf16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim192_fp16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim224_bf16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim224_fp16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim256_bf16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim256_fp16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim32_bf16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim32_fp16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim64_bf16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim64_fp16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim96_bf16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim96_fp16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim128_bf16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim128_fp16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim160_bf16_sm80.cu 
delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim160_fp16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim192_bf16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim192_fp16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim224_bf16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim224_fp16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim256_bf16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim256_fp16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim32_bf16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim32_fp16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim64_bf16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim64_fp16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim96_bf16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim96_fp16_sm80.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernels/generate_kernels.py delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/mask.h delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/rotary.h delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/softmax.h delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/utils.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 92bedacfef4..824be518630 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -872,11 +872,6 @@ cmake_dependent_option( "USE_CUDA OR USE_ROCM;NOT MSVC" OFF) -# We are currenlty not using alibi attention for Flash So we disable this -# feature by default We dont currently document this feature because we don't -# Suspect users building from source will need this -add_definitions(-DFLASHATTENTION_DISABLE_ALIBI) - # CAVEAT: Again, Flash Attention2 will error while building for sm52 while Mem # Eff Attention won't cmake_dependent_option( diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 874b45688d5..34ce25c6e9f 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -164,9 +164,12 @@ file(GLOB native_quantized_cudnn_hip_cpp "native/quantized/cudnn/hip/*.cpp") file(GLOB native_utils_cpp "native/utils/*.cpp") # flash_attention sources -file(GLOB flash_attention_cuda_cu "native/transformers/cuda/flash_attn/*.cu") -file(GLOB flash_attention_cuda_kernels_cu "native/transformers/cuda/flash_attn/kernels/*.cu") -file(GLOB flash_attention_cuda_cpp "native/transformers/cuda/flash_attn/*.cpp") +file(GLOB flash_attention_cuda_kernels_cu ${PROJECT_SOURCE_DIR}/third_party/flash-attention/csrc/flash_attn/src/*.cu) +# Flash attention C++ sources +file(GLOB flash_attention_cuda_cpp + "${PROJECT_SOURCE_DIR}/third_party/flash-attention/csrc/flash_attn/src/*.cpp" + "native/transformers/cuda/flash_attn/flash_api.cpp" +) # flash_attention hip sources file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip") diff --git 
a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 931357b2426..73f675ea05e 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -14852,7 +14852,7 @@ MPS: _scaled_dot_product_attention_math_mps tags: nondeterministic_seeded -- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) +- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor rng_state, Tensor unused, Tensor debug_attn_mask) dispatch: CUDA: _scaled_dot_product_flash_attention_cuda NestedTensorCUDA: _scaled_dot_product_flash_attention_nestedtensor_cuda @@ -14909,13 +14909,13 @@ CUDA: _scaled_dot_product_cudnn_attention_backward_cuda tags: nondeterministic_seeded -- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) +- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor rng_state, Tensor unused, Tensor debug_attn_mask) variants: function dispatch: CUDA: _flash_attention_forward tags: nondeterministic_seeded -- func: _flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None) -> (Tensor, Tensor, Tensor) +- func: _flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor rng_state, Tensor unused, *, float? scale=None, SymInt? window_size_left=None, SymInt? 
window_size_right=None) -> (Tensor, Tensor, Tensor) device_check: NoCheck variants: function dispatch: diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index 7fe7ee7a1ba..0f742d275a1 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -70,6 +70,9 @@ #ifdef USE_FLASH_ATTENTION // FlashAttention Specific Imports #include +#if !defined(__HIP_PLATFORM_AMD__) +#include +#endif #endif #ifdef USE_MEM_EFF_ATTENTION #ifndef USE_ROCM @@ -916,6 +919,7 @@ _flash_attention_forward( std::optional seqused_k = _seqused_k; std::optional block_table = std::nullopt; // we are not using the block table yet std::optional alibi_slopes = _alibi_slopes; + const float softcap = 0.0; const int non_null_window_left = window_size_left.has_value() ? window_size_left.value() : -1; const int non_null_window_right = window_size_right.has_value() ? window_size_right.value() : -1; @@ -939,7 +943,7 @@ _flash_attention_forward( philox_seed, philox_offset, debug_attn_mask) = - pytorch_flash::mha_varlen_fwd( + FLASH_NAMESPACE::mha_varlen_fwd( query, key, value, @@ -957,6 +961,7 @@ _flash_attention_forward( is_causal, non_null_window_left, non_null_window_right, + softcap, return_debug_mask, std::nullopt /*gen_*/); } else { @@ -969,7 +974,7 @@ _flash_attention_forward( philox_seed, philox_offset, debug_attn_mask) = - pytorch_flash::mha_fwd( + FLASH_NAMESPACE::mha_fwd( query, key, value, @@ -980,6 +985,7 @@ _flash_attention_forward( is_causal, non_null_window_left, non_null_window_right, + softcap, return_debug_mask, /*return_softmax (this is used for testing)*/ std::nullopt); } diff --git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu index 09799ff125d..4464d63f12c 100644 --- a/aten/src/ATen/native/transformers/cuda/attention_backward.cu +++ b/aten/src/ATen/native/transformers/cuda/attention_backward.cu @@ -94,6 +94,7 @@ std::tuple _flash_attention_backward( // Currently unused args: std::optional alibi_slopes{std::nullopt}; + const float softcap = 0.0; bool determinisitic{false}; auto& ctx = at::globalContext(); @@ -111,7 +112,7 @@ std::tuple _flash_attention_backward( // in order to determine whether we are using varlen or dense forward if (cumulative_sequence_length_q.defined()) { // Varlen forward - auto [dQuery, dKey, dValue, dSoftmax] = pytorch_flash::mha_varlen_bwd( + auto [dQuery, dKey, dValue, dSoftmax] = FLASH_NAMESPACE::mha_varlen_bwd( contiguous_grad_out, query, key, @@ -132,13 +133,14 @@ std::tuple _flash_attention_backward( is_causal, non_null_window_left, non_null_window_right, + softcap, determinisitic, philox_seed, philox_offset); return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue)); } else { // Dense forward - auto [dQuery, dKey, dValue, dSoftmax] = pytorch_flash::mha_bwd( + auto [dQuery, dKey, dValue, dSoftmax] = FLASH_NAMESPACE::mha_bwd( contiguous_grad_out, query, key, @@ -154,6 +156,7 @@ std::tuple _flash_attention_backward( is_causal, non_null_window_left, non_null_window_right, + softcap, determinisitic, philox_seed, philox_offset); diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/alibi.h b/aten/src/ATen/native/transformers/cuda/flash_attn/alibi.h deleted file mode 100644 index 311231432c7..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/alibi.h +++ /dev/null @@ -1,74 +0,0 @@ -#include - -#include - -#include -#include - 
-#include - -namespace pytorch_flash { - -using namespace cute; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Alibi { - - const float alibi_slope; - const int max_seqlen_k, max_seqlen_q; - - __forceinline__ __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q) - : alibi_slope(alibi_slope) - , max_seqlen_k(max_seqlen_k) - , max_seqlen_q(max_seqlen_q) { - }; - - - template - __forceinline__ __device__ void apply_alibi(Tensor &tensor, - const int col_idx_offset_, - const int row_idx_offset, - const int warp_row_stride) { - // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) - static_assert(Layout::rank == 2, "Only support 2D Tensor"); - const int lane_id = threadIdx.x % 32; - const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; - if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows - #pragma unroll - for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { - const int col_idx_base = col_idx_offset + nj * 8; - #pragma unroll - for (int j = 0; j < size<1, 0>(tensor); ++j) { - const int col_idx = col_idx_base + j; - #pragma unroll - for (int mi = 0; mi < size<0>(tensor); ++mi) { - tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx; - } - } - } - } else { // Bias depends on both row_idx and col_idx - #pragma unroll - for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { - const int row_idx_base = row_idx_offset + mi * warp_row_stride; - #pragma unroll - for (int i = 0; i < size<0, 0>(tensor); ++i) { - const int row_idx = row_idx_base + i * 8; - #pragma unroll - for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { - const int col_idx_base = col_idx_offset + nj * 8; - #pragma unroll - for (int j = 0; j < size<1, 0>(tensor); ++j) { - const int col_idx = col_idx_base + j; - tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx); - } - } - } - } - } - } - -}; - -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/block_info.h b/aten/src/ATen/native/transformers/cuda/flash_attn/block_info.h deleted file mode 100644 index bbaf6978002..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/block_info.h +++ /dev/null @@ -1,46 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#pragma once - -namespace pytorch_flash { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct BlockInfo { - - template - __device__ BlockInfo(const Params ¶ms, const int bidb) - : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]) - , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb]) - , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q) - // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. - // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. - , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) - , actual_seqlen_k(params.seqused_k ? 
params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) - { - } - - template - __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { - return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; - } - - template - __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { - return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride; - } - - const int sum_s_q; - const int sum_s_k; - const int actual_seqlen_q; - // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0. - const int seqlen_k_cache; - const int actual_seqlen_k; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/dropout.h b/aten/src/ATen/native/transformers/cuda/flash_attn/dropout.h deleted file mode 100644 index a40815575ff..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/dropout.h +++ /dev/null @@ -1,96 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include -#include - -namespace pytorch_flash { - -using namespace cute; - -struct Dropout { - - const unsigned long long seed, offset; - const uint8_t p_dropout_in_uint8_t; - - __forceinline__ __device__ Dropout(const unsigned long long seed, const unsigned long long offset, - const uint8_t p_dropout_in_uint8_t, - const int bid, const int hid, const int tid, const int nheads) - : seed(seed) - , offset(offset + (bid * nheads + hid) * 32 + tid % 32) - , p_dropout_in_uint8_t(p_dropout_in_uint8_t) { - } - - template - __forceinline__ __device__ void apply_dropout(Tensor &tensor_, - int block_row_start, int block_col_start, int block_row_stride) { - // convert shape from (4, MMA_M, MMA_N) to (8, MMA_M, MMA_N / 2) - Tensor tensor = make_tensor(tensor_.data(), pytorch_flash::convert_layout_acc_dropout(tensor_.layout())); - using T = typename Engine::value_type; - auto encode_dropout = [](bool keep, T val) { - return keep ? val : (encode_dropout_in_sign_bit ? 
-val : T(0)); - }; - static_assert(decltype(size<2>(tensor))::value % 2 == 0); - const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t); - const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t); - // if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); } - #pragma unroll - for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) { - uint2 rowcol = make_uint2(block_row_start, block_col_start); - #pragma unroll - for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) { - // if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));} - uint4 random_uint4 = pytorch_flash::philox(seed, reinterpret_cast(rowcol), offset); - // if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);} - uint8_t (&rnd_8)[16] = reinterpret_cast(random_uint4); - // Special implementation for 16-bit types: we duplicate the threshold to the - // low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction - // to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000, - // and the high 16 bits will be either 0xffff or 0x0000, depending on whether - // the random value is less than the threshold. - // We then do a bit-wise AND between the mask and the original value (in 32-bit). - // We're exploiting the fact that floating point comparison is equivalent to integer - // comparison, since we're comparing unsigned integers whose top 8-bits are zero. - if (!encode_dropout_in_sign_bit - && (std::is_same_v || std::is_same_v)) { - uint16_t rnd_16[16]; - #pragma unroll - for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); } - uint32_t (&rnd_32)[8] = reinterpret_cast(rnd_16); - #pragma unroll - for (int j = 0; j < 2; j++) { - Tensor tensor_uint32 = recast(tensor(_, m, n * 2 + j)); - // if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); } - // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } - #pragma unroll - for (int i = 0; i < 4; i++) { - uint32_t mask; - asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t)); - tensor_uint32(i) &= mask; - } - // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } - } - } else { - #pragma unroll - for (int j = 0; j < 2; j++) { - #pragma unroll - for (int i = 0; i < 8; i++) { - tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j)); - } - Tensor tensor_uint32 = recast(tensor(_, m, n * 2 + j)); - // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } - } - } - // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // // printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w); - // // } - } - } - } - -}; - -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash.h b/aten/src/ATen/native/transformers/cuda/flash_attn/flash.h deleted file mode 100644 index 9ce14cf6489..00000000000 --- 
a/aten/src/ATen/native/transformers/cuda/flash_attn/flash.h +++ /dev/null @@ -1,190 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include - -#ifdef OLD_GENERATOR_PATH -#include -#else -#include -#endif - -#include // For at::cuda::philox::unpack -namespace pytorch_flash { -constexpr int TOTAL_DIM = 0; -constexpr int H_DIM = 1; -constexpr int D_DIM = 2; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct Qkv_params { - using index_t = int64_t; - // The QKV matrices. - void *__restrict__ q_ptr; - void *__restrict__ k_ptr; - void *__restrict__ v_ptr; - - // The stride between rows of the Q, K and V matrices. - index_t q_batch_stride; - index_t k_batch_stride; - index_t v_batch_stride; - index_t q_row_stride; - index_t k_row_stride; - index_t v_row_stride; - index_t q_head_stride; - index_t k_head_stride; - index_t v_head_stride; - - // The number of heads. - int h, h_k; - // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be - // different from nheads (query). - int h_h_k_ratio; // precompute h / h_k, -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct Flash_fwd_params : public Qkv_params { - - // The O matrix (output). - void * __restrict__ o_ptr; - void * __restrict__ oaccum_ptr; - - // The stride between rows of O. - index_t o_batch_stride; - index_t o_row_stride; - index_t o_head_stride; - - // The pointer to the P matrix. - void * __restrict__ p_ptr; - - // The pointer to the softmax sum. - void * __restrict__ softmax_lse_ptr; - void * __restrict__ softmax_lseaccum_ptr; - - // The dimensions. - int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim; - - // The scaling factors for the kernel. - float scale_softmax; - float scale_softmax_log2; - - // array of length b+1 holding starting offset of each sequence. - int * __restrict__ cu_seqlens_q; - int * __restrict__ cu_seqlens_k; - - // If provided, the actual length of each k sequence. - int * __restrict__ seqused_k; - - int *__restrict__ blockmask; - - // The K_new and V_new matrices. - void * __restrict__ knew_ptr; - void * __restrict__ vnew_ptr; - - // The stride between rows of the Q, K and V matrices. - index_t knew_batch_stride; - index_t vnew_batch_stride; - index_t knew_row_stride; - index_t vnew_row_stride; - index_t knew_head_stride; - index_t vnew_head_stride; - - // The cos and sin matrices for rotary embedding. - void * __restrict__ rotary_cos_ptr; - void * __restrict__ rotary_sin_ptr; - - // The indices to index into the KV cache. - int * __restrict__ cache_batch_idx; - - // Paged KV cache - int * __restrict__ block_table; - index_t block_table_batch_stride; - int page_block_size; - - // The dropout probability (probability of keeping an activation). - float p_dropout; - // uint32_t p_dropout_in_uint; - // uint16_t p_dropout_in_uint16_t; - uint8_t p_dropout_in_uint8_t; - - // Scale factor of 1 / (1 - p_dropout). - float rp_dropout; - float scale_softmax_rp_dropout; - - // Local window size - int window_size_left, window_size_right; - - // Random state. 
- at::PhiloxCudaState philox_args; - int64_t * extragraph_offset; - int64_t * seed; - - bool is_bf16; - bool is_causal; - - // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. - // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. - bool is_seqlens_k_cumulative; - - bool is_rotary_interleaved; - - int num_splits; // For split-KV version - - void * __restrict__ alibi_slopes_ptr; - index_t alibi_slopes_batch_stride; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct Flash_bwd_params : public Flash_fwd_params { - - // The dO and dQKV matrices. - void *__restrict__ do_ptr; - void *__restrict__ dq_ptr; - void *__restrict__ dk_ptr; - void *__restrict__ dv_ptr; - - // To accumulate dQ - void *__restrict__ dq_accum_ptr; - void *__restrict__ dk_accum_ptr; - void *__restrict__ dv_accum_ptr; - - // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q - // dimension void *__restrict__ dk_accum_ptr; void *__restrict__ - // dv_accum_ptr; - - // The stride between rows of the dO, dQ, dK and dV matrices. - // TD [2022-04-16]: We're using 32-bit indexing to save registers. - // The code probably won't work for arrays larger than 2GB. - index_t do_batch_stride; - index_t do_row_stride; - index_t do_head_stride; - index_t dq_batch_stride; - index_t dk_batch_stride; - index_t dv_batch_stride; - index_t dq_row_stride; - index_t dk_row_stride; - index_t dv_row_stride; - index_t dq_head_stride; - index_t dk_head_stride; - index_t dv_head_stride; - - // The pointer to the softmax d sum. - void *__restrict__ dsoftmax_sum; - - bool deterministic; - index_t dq_accum_split_stride; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); - -template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); - -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp index c8f59fe6a71..c3967b7296c 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp @@ -2,6 +2,7 @@ * Copyright (c) 2024, Tri Dao. ******************************************************************************/ #include +#include #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include @@ -9,6 +10,7 @@ #ifdef USE_FLASH_ATTENTION + #include #include #include @@ -32,13 +34,16 @@ #include -#include + +#include +#include +#include #include -#include + #include -namespace pytorch_flash { +namespace FLASH_NAMESPACE { #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == at::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") @@ -70,7 +75,9 @@ void set_params_fprop(Flash_fwd_params ¶ms, float softmax_scale, int window_size_left, int window_size_right, - bool seqlenq_ngroups_swapped=false) { + const float softcap, + bool seqlenq_ngroups_swapped=false, + const bool unpadded_lse=false) { // Reset the parameters params = {}; @@ -126,8 +133,19 @@ void set_params_fprop(Flash_fwd_params ¶ms, params.d_rounded = d_rounded; // Set the different scale values. 
- params.scale_softmax = softmax_scale; - params.scale_softmax_log2 = softmax_scale * M_LOG2E; + #ifdef FLASHATTENTION_DISABLE_SOFTCAP + TORCH_CHECK(softcap <= 0.0, "This flash attention build does not support softcap."); + #endif + if (softcap > 0.0) { + params.softcap = softmax_scale / softcap; + params.scale_softmax = softcap; + params.scale_softmax_log2 = softcap * M_LOG2E; + } else{ + // Remove potential NaN + params.softcap = 0.0; + params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = softmax_scale * M_LOG2E; + } // Set this to probability of keeping an element to simplify things. params.p_dropout = 1.f - p_dropout; @@ -162,6 +180,8 @@ void set_params_fprop(Flash_fwd_params ¶ms, #ifdef FLASHATTENTION_DISABLE_UNEVEN_K TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32."); #endif + params.unpadded_lse = unpadded_lse; + params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped; } void set_params_dgrad(Flash_bwd_params ¶ms, @@ -195,7 +215,9 @@ void set_params_dgrad(Flash_bwd_params ¶ms, float softmax_scale, int window_size_left, int window_size_right, - bool deterministic) { + const float softcap, + bool deterministic, + const bool unpadded_lse) { set_params_fprop(params, b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded, @@ -208,7 +230,10 @@ void set_params_dgrad(Flash_bwd_params ¶ms, p_dropout, softmax_scale, window_size_left, - window_size_right); + window_size_right, + softcap, + false, // seqlenq_ngroups_swapped + unpadded_lse); // Set the pointers and strides. params.do_ptr = dout.data_ptr(); @@ -244,11 +269,13 @@ void set_params_dgrad(Flash_bwd_params ¶ms, void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split_kernel=false) { FP16_SWITCH(!params.is_bf16, [&] { HEADDIM_SWITCH(params.d, [&] { - if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0 - run_mha_fwd_(params, stream); - } else { - run_mha_fwd_splitkv_dispatch(params, stream); - } + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0 + run_mha_fwd_(params, stream); + } else { + run_mha_fwd_splitkv_dispatch(params, stream); + } + }); }); }); } @@ -357,6 +384,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head bool is_causal, int window_size_left, int window_size_right, + const float softcap, const bool return_softmax, std::optional gen_) { @@ -398,6 +426,8 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } + if (window_size_left >= seqlen_k) { window_size_left = -1; } if (window_size_right >= seqlen_k) { window_size_right = -1; } @@ -443,7 +473,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size = round_multiple(head_size_og, 8); - const int head_size_rounded = round_multiple(head_size, 32); + const int head_size_rounded = round_multiple(head_size, 32) < 224 ? 
round_multiple(head_size, 32) : 256; const int seqlen_q_rounded = round_multiple(seqlen_q, 128); const int seqlen_k_rounded = round_multiple(seqlen_k, 128); @@ -478,7 +508,9 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head p_dropout, softmax_scale, window_size_left, - window_size_right); + window_size_right, + softcap + ); // Keep references to these tensors to extend their lifetime @@ -486,10 +518,9 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head head_size, seqlen_k, seqlen_q, head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts); - // We want to checkpoint and save the RNG state for backward if dropout - // We get the default generator and return the seed and offset which will - // be used in the backward function - at::Tensor seed_t, offset_t; + // See [Note] BC breaking change to flash seed/offset + auto rng_state = at::empty({2}, at::dtype(c10::kUInt64).device(at::kCUDA)); + auto _unused = at::empty({}, at::dtype(c10::kUInt64).device(at::kCUDA)); if (p_dropout > 0.0) { auto gen = at::get_generator_or_default(std::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); // number of times random will be generated per thread, to offset philox counter in thc random @@ -499,26 +530,9 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head // See Note [Acquire lock when using random generators] std::lock_guard lock(gen->mutex_); at::PhiloxCudaState philox_state = gen->philox_cuda_state(counter_offset); - if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) { - auto [seed, offset] = at::cuda::philox::unpack(philox_state); - seed_t = at::scalar_tensor(at::Scalar(static_cast(seed)), at::dtype(at::kLong)); - offset_t = at::scalar_tensor(at::Scalar(static_cast(offset)), at::dtype(at::kLong)); - } else { - seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - params.seed = seed_t.data_ptr(); - params.extragraph_offset = offset_t.data_ptr(); - } + rng_state = at::empty({2}, at::TensorOptions().dtype(c10::kUInt64).device(at::kCUDA)); + params.rng_state = reinterpret_cast(rng_state.data_ptr()); params.philox_args = philox_state; - } else { - if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) { - seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - } else { - seed_t = at::empty({}, at::dtype(at::kLong)); - offset_t = at::empty({}, at::dtype(at::kLong)); - } - } set_params_alibi(params, alibi_slopes_, batch_size, num_heads); @@ -537,7 +551,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og}); softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1}); } - return {out, q_padded, k_padded, v_padded, softmax_lse, seed_t, offset_t, p}; + return {out, q_padded, k_padded, v_padded, softmax_lse, rng_state, _unused, p}; } std::tuple @@ -558,6 +572,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q bool is_causal, int window_size_left, int window_size_right, + const float softcap, const bool return_softmax, std::optional gen_) { @@ -608,6 +623,8 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q const int head_size_og = sizes[2]; const int num_heads_k = paged_KV ? 
k.size(2) : k.size(1); + if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } + const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); const int num_blocks = !paged_KV ? 0 : k.size(0); const int page_block_size = !paged_KV ? 1 : k.size(1); @@ -671,7 +688,6 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); CHECK_DEVICE(out); TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); - CHECK_SHAPE(out, total_q, num_heads, head_size_og); CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og); if (seqlenq_ngroups_swapped) { out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og}); @@ -683,7 +699,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size = round_multiple(head_size_og, 8); - const int head_size_rounded = round_multiple(head_size, 32); + const int head_size_rounded = round_multiple(head_size, 32) < 224 ? round_multiple(head_size, 32) : 256; const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); @@ -693,7 +709,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q auto opts = q.options(); - auto softmax_lse = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); + auto softmax_lse = at::empty({num_heads, total_q}, opts.dtype(at::kFloat)); at::Tensor p; // Only return softmax if there's dropout to reduce compilation time if (return_softmax) { @@ -724,7 +740,10 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q softmax_scale, window_size_left, window_size_right, - seqlenq_ngroups_swapped); + softcap, + seqlenq_ngroups_swapped, + /*unpadded_lse*/true); + params.total_q = total_q; if (paged_KV) { params.block_table = block_table.data_ptr(); params.block_table_batch_stride = block_table.stride(0); @@ -741,12 +760,14 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts); } - // We want to checkpoint and save the RNG state for backward if dropout - // We get the default generator and return the seed and offset which will - // be used in the backward function - auto gen = at::get_generator_or_default(std::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); - at::Tensor seed_t, offset_t; + // [Note] BC breaking change to flash seed/offset + // Previously: Used separate tensors for philox_seed and philox_offset, sometimes on CPU, sometimes on CUDA + // FlashAttention change: Now uses a single uint64_t[2] tensor on device containing both seed and offset + // Implementation: Renamed "seed" → "rng_state" (contains both seed+offset) and "offset" → "_unused" + auto rng_state = at::empty({2}, at::dtype(c10::kUInt64).device(at::kCUDA)); + auto _unused = at::empty({}, at::dtype(c10::kUInt64).device(at::kCUDA)); if (p_dropout > 0.0) { + auto gen = at::get_generator_or_default(std::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); // number of times random will be generated per thread, to offset philox counter in thc random // state // We use a custom RNG that increases the offset by batch_size * nheads * 32. 
@@ -754,26 +775,9 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q // See Note [Acquire lock when using random generators] std::lock_guard lock(gen->mutex_); at::PhiloxCudaState philox_state = gen->philox_cuda_state(counter_offset); - if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) { - auto [seed, offset] = at::cuda::philox::unpack(philox_state); - seed_t = at::scalar_tensor(at::Scalar(static_cast(seed)), at::dtype(at::kLong)); - offset_t = at::scalar_tensor(at::Scalar(static_cast(offset)), at::dtype(at::kLong)); - } else { - seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - params.seed = seed_t.data_ptr(); - params.extragraph_offset = offset_t.data_ptr(); - } + rng_state = at::empty({2}, at::TensorOptions().dtype(c10::kUInt64).device(at::kCUDA)); + params.rng_state = reinterpret_cast(rng_state.data_ptr()); params.philox_args = philox_state; - } else { - if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) { - seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - } else { - seed_t = at::empty({}, at::dtype(at::kLong)); - offset_t = at::empty({}, at::dtype(at::kLong)); - } - } set_params_alibi(params, alibi_slopes_, batch_size, num_heads); @@ -792,16 +796,18 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q std::array size_after = {batch_size, num_heads_k * max_seqlen_q, head_size_og}; out = out.reshape(size_before).transpose(1, 2).reshape(size_after); q_padded = q_padded.reshape(size_before).transpose(1, 2).reshape(size_after); - softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * max_seqlen_q, 1}); + softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size}); } - return {out, q_padded, k_padded, v_padded, softmax_lse, seed_t, offset_t, p}; + return {out, q_padded, k_padded, v_padded, softmax_lse, rng_state, _unused, p}; } void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { FP16_SWITCH(!params.is_bf16, [&] { HEADDIM_SWITCH(params.d, [&] { - run_mha_bwd_(params, stream); + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + run_mha_bwd_(params, stream); + }); }); }); } @@ -822,6 +828,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si const bool is_causal, int window_size_left, int window_size_right, + const float softcap, const bool deterministic, const at::Tensor philox_seed, const at::Tensor philox_offset) { @@ -883,7 +890,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int head_size_rounded = round_multiple(head_size, 32); + const int head_size_rounded = round_multiple(head_size, 32) < 224 ? round_multiple(head_size, 32) : 256; const int seqlen_q_rounded = round_multiple(seqlen_q, 128); const int seqlen_k_rounded = round_multiple(seqlen_k, 128); @@ -982,21 +989,17 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si softmax_scale, window_size_left, window_size_right, - deterministic); + softcap, + deterministic, + /*unpadded_lse*/false); params.dq_accum_split_stride = !deterministic ? 
0 : dq_accum.stride(0); auto launch = &run_mha_bwd; at::PhiloxCudaState philox_args; + if (is_dropout) { - if (at::cuda::currentStreamCaptureStatus() == - at::cuda::CaptureStatus::None) - { - philox_args = at::PhiloxCudaState(*philox_seed.data_ptr(), *philox_offset.data_ptr()); - } else { // dropout + capture - philox_args = at::PhiloxCudaState( - philox_seed.data_ptr(), philox_offset.data_ptr(), 0); - } + params.rng_state = philox_seed.data_ptr(); } params.philox_args = philox_args; @@ -1025,7 +1028,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &out, // total_q x num_heads x head_size - const at::Tensor &softmax_lse, // b x h x s softmax logsumexp + const at::Tensor &softmax_lse, // h x total_q, softmax logsumexp std::optional &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i std::optional &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i std::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i @@ -1040,6 +1043,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size const bool is_causal, int window_size_left, int window_size_right, + const float softcap, const bool deterministic, const at::Tensor philox_seed, const at::Tensor philox_offset) @@ -1107,7 +1111,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int head_size_rounded = round_multiple(head_size, 32); + const int head_size_rounded = round_multiple(head_size, 32) < 224 ? round_multiple(head_size, 32) : 256; const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); @@ -1162,7 +1166,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size at::cuda::CUDAGuard device_guard{(char)q.get_device()}; auto opts = q.options(); - auto softmax_d = at::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat)); + auto softmax_d = at::empty({num_heads, total_q + 128 * batch_size}, opts.dtype(at::kFloat)); at::Tensor dq_accum; if (loop) { // We don't want to allocate dq_accum of size (batch, seqlen_q_rounded, num_heads, head_size_rounded) @@ -1173,6 +1177,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size // cu_seqlens[i + 1] * 128 * i - 1. This ensures that the i-th sequence and (i + 1)-th sequence will // be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally // allowed to do. So we won't have to do any bound checking, and performance should stay the same. + // Same holds for softmax_d, since LSE is stored in unpadded format. if (!deterministic) { dq_accum = at::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); } else { @@ -1218,21 +1223,17 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size softmax_scale, window_size_left, window_size_right, - deterministic); + softcap, + deterministic, + /*unpadded_lse*/true); + params.total_q = total_q;; params.dq_accum_split_stride = !deterministic ? 
0 : dq_accum.stride(0); auto launch = &run_mha_bwd; at::PhiloxCudaState philox_args; if (is_dropout) { - if (at::cuda::currentStreamCaptureStatus() == - at::cuda::CaptureStatus::None) - { - philox_args = at::PhiloxCudaState(*philox_seed.data_ptr(), *philox_offset.data_ptr()); - } else { // dropout + capture - philox_args = at::PhiloxCudaState( - philox_seed.data_ptr(), philox_offset.data_ptr(), 0); - } + params.rng_state = philox_seed.data_ptr(); } params.philox_args = philox_args; @@ -1273,6 +1274,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he bool is_causal, int window_size_left, int window_size_right, + const float softcap, bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 int num_splits ) { @@ -1385,7 +1387,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size = round_multiple(head_size_og, 8); - const int head_size_rounded = round_multiple(head_size, 32); + const int head_size_rounded = round_multiple(head_size, 32) < 224 ? round_multiple(head_size, 32) : 256; const int seqlen_q_rounded = round_multiple(seqlen_q, 128); const int seqlen_k_rounded = round_multiple(seqlen_k, 128); @@ -1413,7 +1415,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he /*p_dropout=*/0.f, softmax_scale, window_size_left, - window_size_right); + window_size_right, + softcap + ); at::Tensor k, v, k_padded, v_padded; if (k_.has_value()) { diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.h b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.h index ea5f577d5a2..f5ba2c117d9 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.h +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.h @@ -1,10 +1,11 @@ #pragma once #include +#include #include #include -namespace pytorch_flash { +namespace FLASH_NAMESPACE { TORCH_API std::tuple @@ -18,6 +19,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head bool is_causal, int window_size_left, int window_size_right, + const float softcap, const bool return_softmax, std::optional gen_); @@ -39,6 +41,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q bool is_causal, int window_size_left, int window_size_right, + const float softcap, const bool return_softmax, std::optional gen_); @@ -59,6 +62,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si const bool is_causal, int window_size_left, int window_size_right, + const float softcap, const bool deterministic, const at::Tensor philox_seed, const at::Tensor philox_offset); @@ -84,8 +88,9 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size const bool is_causal, int window_size_left, int window_size_right, + const float softcap, const bool deterministic, const at::Tensor philox_seed, const at::Tensor philox_offset); -} // namespace pytorch_flash +} // namespace FLASH_NAMESPACE diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_kernel.h b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_kernel.h deleted file mode 100644 index 3b06caa89e4..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_kernel.h +++ /dev/null @@ -1,827 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024, Tri Dao. 
- ******************************************************************************/ - -#pragma once - -#include -#include - -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include - -namespace pytorch_flash { - -using namespace cute; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTE_HOST_DEVICE -auto -make_tiled_copy_B_warpcontiguousN(Copy_Atom const& copy_atom, - TiledMMA const& tiled_mma) { - constexpr int TileShape_N = decltype(tiled_mma.template tile_size_mnk<1>())::value; - constexpr int TileShape_K = decltype(tiled_mma.template tile_size_mnk<2>())::value; - using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; - constexpr int AtomShape_N = decltype(size<1>(AtomShape_MNK{}))::value; - // Divide by 2 because right now we always use 2 for the ValLayout - constexpr int kNWarpsN = TileShape_N / AtomShape_N / 2; - constexpr int MMAStride_N = MMA_N * AtomShape_N * 2; - // This gives the correct layout, idk why. - // auto t = make_tile(Layout, _2>, - // Stride, _8> >{}, - // auto t = make_tile(Layout, - // Stride<_1, _64, _8> >{}, - auto t = make_tile(Layout, Int, _2>, // (8, 2, 2) or (8, 4, 2) - Stride<_1, Int, _8> >{}, // (1, 64, 8) or (1, 32, 8) - make_layout(Int{})); - // if (cute::thread0()) {printf("make_tiled_copy_B_warpcontiguousN "); print(t); printf("\n"); } - return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutB_TV(), t); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTE_HOST_DEVICE -auto -make_tiled_copy_C_warpcontiguousN(Copy_Atom const& copy_atom, - TiledMMA const& tiled_mma) { - constexpr int TileShape_M = decltype(tiled_mma.template tile_size_mnk<0>())::value; - constexpr int TileShape_N = decltype(tiled_mma.template tile_size_mnk<1>())::value; - using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; - constexpr int AtomShape_N = decltype(size<1>(AtomShape_MNK{}))::value; - // Divide by 2 because right now we always use 2 for the ValLayout - constexpr int kNWarpsN = TileShape_N / AtomShape_N / 2; - constexpr int MMAStride_N = MMA_N * AtomShape_N * 2; - auto t = make_tile(make_layout(Int{}), - Layout, Int, _2>, // (8, 2, 2) or (8, 4, 2) - Stride<_1, Int, _8> >{}); // (1, 64, 8) or (1, 32, 8) - // if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousN "); print(t); printf("\n"); } - return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const int bidb, const int bidh, const int n_block) { - - using Element = typename Kernel_traits::Element; - using ElementAccum = typename Kernel_traits::ElementAccum; - using index_t = typename Kernel_traits::index_t; - - // Shared memory. - extern __shared__ char smem_[]; - - // The thread index. 
- const int tidx = threadIdx.x; - - constexpr int kBlockM = Kernel_traits::kBlockM; - constexpr int kBlockN = Kernel_traits::kBlockN; - constexpr int kHeadDim = Kernel_traits::kHeadDim; - constexpr int MMA_N_SdP = kBlockN / decltype(typename Kernel_traits::TiledMmaSdP{}.template tile_size_mnk<1>())::value; - constexpr int AtomLayoutMS = Kernel_traits::AtomLayoutMSdP; - constexpr bool Double_buffer = !Kernel_traits::No_double_buffer; - - const BlockInfo binfo(params, bidb); - if (n_block * kBlockN >= binfo.actual_seqlen_k) return; - - int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM); - if (Is_local) { - m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left, kBlockM)); - } - - const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) - + (m_block_max - 1) * kBlockM * params.q_row_stride + bidh * params.q_head_stride; - const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) - + n_block * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; - const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) - + n_block * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; - const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb) - + (m_block_max - 1) * kBlockM * params.do_row_stride + bidh * params.do_head_stride; - const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) - + (m_block_max - 1) * kBlockM * params.o_row_stride + bidh * params.o_head_stride; - const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb) - + (m_block_max - 1) * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride; - const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb) - + ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded - // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer. - + (!params.deterministic ? 
0 : blockIdx.x * params.dq_accum_split_stride); - const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q - + (m_block_max - 1) * kBlockM; - const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded - + (m_block_max - 1) * kBlockM; - - Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), - Shape, Int>{}, - make_stride(params.q_row_stride, _1{})); - Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), - Shape, Int>{}, - make_stride(params.k_row_stride, _1{})); - Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), - Shape, Int>{}, - make_stride(params.v_row_stride, _1{})); - Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast(params.do_ptr) + row_offset_do), - Shape, Int>{}, - make_stride(params.do_row_stride, _1{})); - Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), - Shape, Int>{}, - make_stride(params.o_row_stride, _1{})); - Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_ptr) + row_offset_dq), - Shape, Int>{}, - make_stride(params.dq_row_stride, _1{})); - Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_accum_ptr) + row_offset_dq_accum), - Shape, Int>{}, - make_stride(params.h * params.d_rounded, _1{})); - Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), - Shape>{}, Stride<_1>{}); - Tensor gdPsum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dsoftmax_sum) + row_offset_dpsum), - Shape>{}, Stride<_1>{}); - - Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), - typename Kernel_traits::SmemLayoutQdO{}); - Tensor sQt = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOtransposed{}); - Tensor sQtNoSwizzle = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{}); - // Double buffer for sQ - Tensor sdO = make_tensor(sQ.data() + (Double_buffer ? 2 : 1) * size(sQ), typename Kernel_traits::SmemLayoutQdO{}); - Tensor sdOt = make_tensor(sdO.data(), typename Kernel_traits::SmemLayoutQdOtransposed{}); - Tensor sdOtransposedNoSwizzle = make_tensor(sdO.data(), - typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{}); - Tensor sK = make_tensor(sdO.data() + size(sdO), typename Kernel_traits::SmemLayoutKV{}); - Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); - Tensor sKt = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKtransposed{}); - Tensor sKtNoSwizzle = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKtransposedNoSwizzle{}); - Tensor sdS = make_tensor(!Kernel_traits::Is_V_in_regs ? 
sV.data() + size(sV) : sK.data() + size(sK), - typename Kernel_traits::SmemLayoutPdS{}); - Tensor sdSt = make_tensor(sdS.data(), typename Kernel_traits::SmemLayoutPdStransposed{}); - Tensor sdStNoSwizzle = make_tensor(sdS.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{}); - Tensor sP = make_tensor(sdS.data() + size(sdS), typename Kernel_traits::SmemLayoutPdS{}); - Tensor sPt = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposed{}); - Tensor sPtNoSwizzle = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{}); - // sP and sdQ share the same memory so be careful - Tensor sdQ = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutdQ{}); - - typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; - auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); - using GmemTiledCopydO = std::conditional_t< - Is_first, - typename Kernel_traits::GmemTiledCopydO, - typename Kernel_traits::GmemTiledCopyQKV - >; - GmemTiledCopydO gmem_tiled_copy_dO; - auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx); - typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ; - auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx); - using GmemLayoutAtomdQaccum = std::conditional_t< - !Seq_parallel, - typename Kernel_traits::GmemTiledCopydQaccum, - typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd - >; - GmemLayoutAtomdQaccum gmem_tiled_copy_dQaccum; - auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx); - - Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); - Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); - Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO); - Tensor tdOsdO = gmem_thr_copy_dO.partition_D(sdO); - Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO); - Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) - Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); - Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) - Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); - Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ); - Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum); - // if (cute::thread0()) { print(tdQgdQaccum.layout()); printf("\n"); } - // __syncthreads(); - // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx < 64) { - // printf("tidx = %d, tdQgdQaccum = 0x%p\n", tidx, tdQgdQaccum.data()); - // } - - typename Kernel_traits::TiledMmaSdP tiled_mma_sdp; - auto thr_mma_sdp = tiled_mma_sdp.get_thread_slice(tidx); - Tensor tSrQ = thr_mma_sdp.partition_fragment_A(sQ); // (MMA,MMA_N,MMA_K) - Tensor tSrK = thr_mma_sdp.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) - Tensor tdPrdO = thr_mma_sdp.partition_fragment_A(sdO); // (MMA,MMA_N,MMA_K) - Tensor tdPrV = thr_mma_sdp.partition_fragment_B(sV); // (MMA,MMA_N,MMA_K) - - typename Kernel_traits::TiledMmadKV tiled_mma_dkv; - auto thr_mma_dkv = tiled_mma_dkv.get_thread_slice(tidx); - Tensor tdKrdSt = thr_mma_dkv.partition_fragment_A(sdStNoSwizzle); // (MMA, MMA_N, MMA_N) - Tensor tdKrQt = thr_mma_dkv.partition_fragment_B(sQtNoSwizzle); // (MMA, MMA_K, MMA_N) - Tensor tdVrPt = thr_mma_dkv.partition_fragment_A(sPtNoSwizzle); // (MMA, MMA_N, MMA_N) - Tensor tdVrdO = thr_mma_dkv.partition_fragment_B(sdOtransposedNoSwizzle); // (MMA, MMA_K, MMA_N) - - typename Kernel_traits::TiledMmadQ tiled_mma_dq; - auto thr_mma_dq = tiled_mma_dq.get_thread_slice(tidx); - Tensor tdQrdS = 
thr_mma_dq.partition_fragment_A(sdS); // (MMA, MMA_N, MMA_N) - Tensor tdQrKt = thr_mma_dq.partition_fragment_B(sKtNoSwizzle); // (MMA, MMA_K, MMA_N) - - Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K - Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K - - // - // Copy Atom retiling - // - - auto smem_tiled_copy_QdO = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp); - auto smem_thr_copy_QdO = smem_tiled_copy_QdO.get_thread_slice(tidx); - Tensor tSsQ = smem_thr_copy_QdO.partition_S(sQ); - Tensor tdPsdO = smem_thr_copy_QdO.partition_S(sdO); - - // auto smem_thr_copy_KV = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp).get_thread_slice(tidx); - auto smem_tiled_copy_KV = make_tiled_copy_B_warpcontiguousN(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp); - auto smem_thr_copy_KV = smem_tiled_copy_KV.get_thread_slice(tidx); - Tensor tSsK = smem_thr_copy_KV.partition_S(sK); - // if (cute::thread(0, 0) && n_block == 0) { printf("sK layout: "); print(sK.layout()); printf("\n"); } - // if (cute::thread(0, 0) && n_block == 0) { print(tSsK.layout()); printf("\n"); } - Tensor tdPsV = smem_thr_copy_KV.partition_S(sV); - - // Partition sP and sdS to match the accumulator partitioning - // This has to be tiled_mma_sdp, not tiled_mma_dkv - // auto smem_thr_copy_PdS = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp).get_thread_slice(tidx); - auto smem_tiled_copy_PdS = make_tiled_copy_C_warpcontiguousN(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp); - auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(tidx); - Tensor tPsP = smem_thr_copy_PdS.partition_D(sP); // ((Atom,AtomNum),PIPE_M,PIPE_N) - // if (cute::thread(0, 0) && n_block == 0) { printf("sP layout: "); print(sP.layout()); printf("\n"); } - // if (cute::thread(0, 0) && n_block == 0) { print(tPsP.layout()); printf("\n"); } - // if (n_block == 0 && blockIdx.x == 0 && blockIdx.y == 0 && tidx < 64) { - // printf("tidx=%d, tPsP = 0x%p\n", tidx, tPsP.data()); - // } - Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdS); // ((Atom,AtomNum),PIPE_M,PIPE_N) - - auto smem_tiled_copy_PdSt = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv); - auto smem_thr_copy_PdSt = smem_tiled_copy_PdSt.get_thread_slice(tidx); - Tensor tdVsPt = smem_thr_copy_PdSt.partition_S(sPt); - Tensor tdKsdSt = smem_thr_copy_PdSt.partition_S(sdSt); - - auto smem_tiled_copy_QdOt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv); - auto smem_thr_copy_QdOt = smem_tiled_copy_QdOt.get_thread_slice(tidx); - Tensor tdVsdOt = smem_thr_copy_QdOt.partition_S(sdOt); - Tensor tdKsQt = smem_thr_copy_QdOt.partition_S(sQt); - - auto smem_tiled_copy_dS = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_dq); - auto smem_thr_copy_dS = smem_tiled_copy_dS.get_thread_slice(tidx); - Tensor tdQsdS = smem_thr_copy_dS.partition_S(sdS); - - auto smem_tiled_copy_Kt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dq); - auto smem_thr_copy_Kt = smem_tiled_copy_Kt.get_thread_slice(tidx); - Tensor tdQsKt = smem_thr_copy_Kt.partition_S(sKt); - - auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq); - auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx); - Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N) - - // - // PREDICATES - 
// - - Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) - Tensor tQcQ = gmem_thr_copy_QKV.partition_D(cQ); - Tensor tKVcKV = gmem_thr_copy_QKV.partition_D(cKV); - - // Allocate predicate tensors for k - Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); - Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); - - // Set predicates for k bounds - if (!Is_even_K) { - #pragma unroll - for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } - #pragma unroll - for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } - } - - // Prologue - - // We'll advance gdQ and gdQaccum before the 1st read/write. - tdQgdQ.data() = tdQgdQ.data() + kBlockM * params.dq_row_stride; - tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.h * params.d_rounded; - - int m_block = m_block_max - 1; - int m_block_min = (!Is_causal && !Is_local) - ? 0 - : std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right) / kBlockM); - // If not local, we're guaranteed that m_block_min <= m_block: - // We checked earlier that n_block * kBlockN < actual_seqlen_k, so in the causal case, - // n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k < actual_seqlen_q. - // So m_block_min <= (actual_seqlen_q - 1) / kBlockM. - // Recall that m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM) = (actual_seqlen_q + kBlockM - 1) / kBlockM. - // So m_block_m - 1 = (actual_seqlen_q - 1) / kBlockM. - // We conclude that m_block_min <= m_block, so we will always have at least 1 iteration of the for loop. - // However, if local, then this possible to have some blocks of K & V not attending to any query. - // We might need to exit early and write 0 to dK and dV for those blocks. - // Otherwise we get wrong result for the case where we don't enter the for loop. - // And we might read OOB elements from gQ and gdO. 
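// ---------------------------------------------------------------------------
// Illustrative sketch (not part of this patch): the block-range arithmetic
// above, restated as a standalone host helper with made-up sizes, to show how
// a K/V column block can end up with no query rows at all in the local case
// (m_block_max - 1 < m_block_min), which is exactly what the early-exit path
// below has to handle by writing zeros to dK/dV and returning.
#include <algorithm>
#include <cstdio>

struct BlockRange { int m_block_min, m_block_max; };

static BlockRange query_block_range(int seqlen_q, int seqlen_k, int n_block,
                                    int kBlockM, int kBlockN,
                                    bool is_causal, bool is_local,
                                    int window_left, int window_right) {
  auto ceil_div = [](int a, int b) { return (a + b - 1) / b; };
  int m_block_max = ceil_div(seqlen_q, kBlockM);
  if (is_local) {
    m_block_max = std::min(
        m_block_max,
        ceil_div((n_block + 1) * kBlockN + seqlen_q - seqlen_k + window_left, kBlockM));
  }
  int m_block_min = (!is_causal && !is_local)
      ? 0
      : std::max(0, (n_block * kBlockN + seqlen_q - seqlen_k - window_right) / kBlockM);
  return {m_block_min, m_block_max};
}

int main() {
  // Causal, 512 queries x 512 keys: every K/V block sees at least one query block.
  BlockRange c = query_block_range(512, 512, /*n_block=*/3, 128, 128,
                                   /*is_causal=*/true, /*is_local=*/false,
                                   /*window_left=*/-1, /*window_right=*/0);
  // Local with a narrow window and far more keys than queries: the first K/V
  // block lies outside every query's window, so there is nothing to iterate over.
  BlockRange l = query_block_range(256, 1024, /*n_block=*/0, 128, 128,
                                   /*is_causal=*/false, /*is_local=*/true,
                                   /*window_left=*/128, /*window_right=*/128);
  std::printf("causal n_block=3: %s\n", c.m_block_min < c.m_block_max ? "has work" : "early exit");
  std::printf("local  n_block=0: %s\n", l.m_block_min < l.m_block_max ? "has work" : "early exit");
}
// ---------------------------------------------------------------------------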
- // This also covers the case where actual_seqlen_q == 0 - if ((Is_local || !Is_even_MN) && m_block < m_block_min) { - const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb) - + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride; - const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb) - + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride; - Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_ptr) + row_offset_dk), - Shape, Int>{}, - make_stride(params.dk_row_stride, _1{})); - Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_ptr) + row_offset_dv), - Shape, Int>{}, - make_stride(params.dv_row_stride, _1{})); - typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV; - auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx); - Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK); - Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV); - Tensor tdKrdK = make_tensor(shape(tdKgdK)); - Tensor tdVrdV = make_tensor(shape(tdVgdV)); - clear(tdKrdK); - clear(tdVrdV); - Tensor cdKV = make_identity_tensor(make_shape(size<0>(gdK), size<1>(gdK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) - Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); - Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKgdK))); - #pragma unroll - for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; } - // Clear_OOB_K must be false since we don't want to write zeros to gmem - pytorch_flash::copy( - gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN - ); - pytorch_flash::copy( - gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN - ); - return; - } - - if (Double_buffer && m_block % 2 == 1) { // Double buffer for sQ - tQsQ.data() = tQsQ.data() + size(sQ); - tSsQ.data() = tSsQ.data() + size(sQ); - tdKsQt.data() = tdKsQt.data() + size(sQ); - } - - if ((!Is_first && !Seq_parallel) || params.deterministic) { __syncthreads(); } - - if (Kernel_traits::Is_V_in_regs) { - // Clear the smem tiles to account for predicated off loads - pytorch_flash::copy( - gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN - ); - pytorch_flash::cp_async_fence(); - } - - Tensor tdOrdO = make_fragment_like(tdOgdO); - Tensor tdOrO = make_fragment_like(tdOgO); - if (!Is_first) { - // Clear the smem tiles to account for predicated off loads - pytorch_flash::copy( - gmem_tiled_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM - ); - } else { - pytorch_flash::copy( - gmem_tiled_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM - ); - pytorch_flash::copy( - gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM - ); - } - pytorch_flash::copy( - gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM - ); - - Tensor caccS = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n) - Tensor taccScS = thr_mma_sdp.partition_C(caccS); // (MMA,MMA_N,MMA_N) - static_assert(decltype(size<0>(taccScS))::value == 4); - // Convert to ((2, 2), MMA_N, MMA_N) then take only the row indices. 
- Tensor taccScS_row = logical_divide(taccScS, Shape<_2>{})(make_coord(0, _), _, 0); - Tensor lse = make_tensor(Shape>{}); - #pragma unroll - for (int mi = 0; mi < size(lse); ++mi) { - const int row = get<0>(taccScS_row(mi)); - lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY; - } - // We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero, - // and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply - // with V (which would be zero), we're fine. However, with ALiBi, we might modify these - // scores, and probs can become NaN. Instead if we set LSE = inf for OOB rows, probs are always 0. - - // Tensor tKrK = make_fragment_like(tKsK); - // // cute::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, 0), tKrK); - // cute::copy(gmem_tiled_copy_QKV, tKgK, tKrK); - // // if (cute::thread(1, 0)) { print(tKrK); } - - pytorch_flash::copy( - gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN - ); - if (!Kernel_traits::Is_V_in_regs) { - pytorch_flash::copy( - gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN - ); - } - pytorch_flash::cp_async_fence(); - - // if (cute::thread0()) { print(tdOgdO.layout()); printf("\n"); print(tdOrdO); print(tdOrO); } - if (Is_first) { - cute::copy(tdOrdO, tdOsdO); - dot_do_o(tdOrdO, tdOrO, gdPsum, - Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout); - } - - if (Kernel_traits::Is_V_in_regs) { - cute::cp_async_wait<1>(); - __syncthreads(); - Tensor tdPrV_copy_view = smem_thr_copy_KV.retile_D(tdPrV); - CUTE_STATIC_ASSERT_V(size<1>(tdPsV) == size<1>(tdPrV_copy_view)); // M - cute::copy(smem_tiled_copy_KV, tdPsV, tdPrV_copy_view); - } - - const auto [seed, offset] = at::cuda::philox::unpack(params.philox_args); - pytorch_flash::Dropout dropout(seed, offset, params.p_dropout_in_uint8_t, - bidb, bidh, tidx, params.h); - - clear(acc_dv); - clear(acc_dk); - - const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 
0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; - pytorch_flash::Alibi alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q); - - for (; m_block >= m_block_min; --m_block) { - Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape, Int>{}); // (MMA=4, MMA_N, MMA_N) - clear(acc_s); - cute::cp_async_wait<0>(); - __syncthreads(); - - Tensor dP_sum = make_fragment_like(lse); - #pragma unroll - for (int mi = 0; mi < size(lse); ++mi) { dP_sum(mi) = gdPsum(get<0>(taccScS_row(mi))); } - - // if (cute::thread0()) { print(sK); } - // Tensor tSrK_copy_view = smem_thr_copy_KV.retile_D(tSrK); - // #pragma unroll - // for (int k = 0; k < size<2>(tSrK_copy_view); ++k) { - // cute::copy(smem_tiled_copy_KV, tSsK(_, _, k), tSrK_copy_view(_, _, k)); - // } - // if (cute::thread0()) { print(tSrK); } - pytorch_flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp, - smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV); - - // Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N)) - Tensor scores = make_tensor(acc_s.data(), pytorch_flash::convert_layout_acc_rowcol(acc_s.layout())); - // if (cute::thread(32, 0)) { print(scores); } - - if (Has_alibi) { - alibi.apply_alibi(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, - m_block * kBlockM + get<0>(taccScS_row(0)), AtomLayoutMS * 16); - } - - // TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond - // actual_seqlen_k, because acc_s would be some finite value for those indices. - // In the end when we multiply with K to get dQ, the corresponding values of K would be 0, - // so the result would still be correct. - // However, it's possible that the values in acc_s are so large that they overflow - // when we multiply with dP and convert to fp16, resulting in Inf in dS and NaNs in dQ. - // So we need to mask out the elements beyond actual_seqlen_k. - if (!Is_causal && !Is_local) { - if (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k) { - pytorch_flash::apply_mask(scores, binfo.actual_seqlen_k, - n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16); - } - } else if (Is_causal) { - // Putting this causal masking right after acc_s is *much* slower for some reason. - // TD [2023-08-16]: We need the 2nd condition because if seqlen_q is long and seqlen_k is short - // (e.g., 256 and 2), the 2nd block of seqlen_q (from 128 to 255), we're not doing causal masking. - // But we still want to mask out elements beyond actual_seqlen_k. 
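// ---------------------------------------------------------------------------
// Illustrative sketch (not part of this patch): the trigger used just below,
// pulled out as a standalone predicate. With the bottom-right-aligned causal
// convention used here, row i may only attend columns j <= i + seqlen_k - seqlen_q,
// so a K/V block needs masking when it straddles that diagonal for this query
// block; the second clause additionally catches blocks whose tail columns run
// past actual_seqlen_k (the seqlen_q = 256 / seqlen_k = 2 case described above).
// Names and example sizes are illustrative only.
#include <cstdio>

static bool block_needs_causal_mask(int m_block, int n_block, int kBlockM, int kBlockN,
                                    int seqlen_q, int seqlen_k, bool is_even_MN) {
  const bool straddles_diagonal =
      m_block * kBlockM < (n_block + 1) * kBlockN + seqlen_q - seqlen_k;
  const bool has_oob_columns = !is_even_MN && (n_block + 1) * kBlockN >= seqlen_k;
  return straddles_diagonal || has_oob_columns;
}

int main() {
  // The example from the comment above: 256 queries, only 2 keys.
  for (int m_block = 0; m_block < 2; ++m_block) {
    std::printf("m_block=%d: %s\n", m_block,
                block_needs_causal_mask(m_block, /*n_block=*/0, 128, 128,
                                        /*seqlen_q=*/256, /*seqlen_k=*/2,
                                        /*is_even_MN=*/false)
                    ? "apply causal/OOB mask" : "skip masking");
  }
}
// ---------------------------------------------------------------------------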
- if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - || (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) { - pytorch_flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, - binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)), - binfo.actual_seqlen_q, - // binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4, - AtomLayoutMS * 16); - } - } else if (Is_local) { - if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right - || (m_block + 1) * kBlockM >= n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left - || (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) { - pytorch_flash::apply_mask_local(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, - binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)), - binfo.actual_seqlen_q, AtomLayoutMS * 16, - params.window_size_left, params.window_size_right); - } - - } - - // if (cute::thread(32, 0)) { print(scores); } - // Compute the exponential value. - pytorch_flash::scale_apply_exp2(scores, lse, params.scale_softmax_log2); - if constexpr (Is_dropout) { - int warp_id = tidx / 32; - int block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS; - // Need col to be multiples of 32, since we're doing dropout with block of 16 x 32 - static_assert(MMA_N_SdP % 2 == 0); - int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2); - dropout.template apply_dropout( - acc_s, block_row_idx, block_col_idx, AtomLayoutMS - ); - } - // Convert scores from fp32 to fp16/bf16 - Tensor rP = !Is_dropout - ? pytorch_flash::convert_type(acc_s) - : pytorch_flash::convert_type_relu(acc_s); - // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_N, MMA_N / 2) - // if using m16n8k16 or (4, MMA_N, MMA_N) if using m16n8k8. - Tensor tPrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout())); - Tensor tPaP = smem_thr_copy_PdS.retile_S(tPrP); // ((Atom,AtomNum), MMA_N, MMA_N) - cute::copy(smem_tiled_copy_PdS, tPaP, tPsP); - // if (cute::thread0()) { print(tPaP); } - // __syncthreads(); - // if (cute::thread0()) { print(sP); } - - Tensor acc_dp = partition_fragment_C(tiled_mma_sdp, Shape, Int>{}); // (MMA=4, MMA_N, MMA_N) - CUTE_STATIC_ASSERT_V(size<0>(acc_dp) == size<0>(acc_s)); // MMA - CUTE_STATIC_ASSERT_V(size<1>(acc_dp) == size<1>(acc_s)); // MMA - CUTE_STATIC_ASSERT_V(size<2>(acc_dp) == size<2>(acc_s)); // MMA - - clear(acc_dp); - // Tensor acc_dp_reshaped = make_tensor(acc_dp.data(), pytorch_flash::convert_layout_acc_rowcol(acc_dp.layout())); - // #pragma unroll - // for (int mi = 0; mi < size<0>(acc_dp_reshaped); ++mi) { - // #pragma unroll - // for (int ni = 0; ni < size<1>(acc_dp_reshaped); ++ni) { - // acc_dp_reshaped(mi, ni) = -dP_sum(mi); - // } - // } - - // if (cute::thread0()) { print(dP_sum); } - - pytorch_flash::gemm( - acc_dp, tdPrdO, tdPrV, tdPsdO, tdPsV, tiled_mma_sdp, - smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV - ); - - // Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N)) - Tensor dS = make_tensor(acc_dp.data(), scores.layout()); - auto pointwise_mult = [](float p, float dp, float d) { - return p * (!Is_dropout || p >= 0 ? 
dp - d : d); - }; - #pragma unroll - for (int mi = 0; mi < size<0>(dS); ++mi) { - #pragma unroll - for (int ni = 0; ni < size<1>(dS); ++ni) { - dS(mi, ni) = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi)); - } - } - // if (cute::thread0()) { print(dS); } - - Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape, Int>{}); // MMA, MMA_N, MMA_K - tdQgdQaccum.data() = tdQgdQaccum.data() + (-int(kBlockM * params.h * params.d_rounded)); - if (Is_first || Seq_parallel) { - clear(acc_dq); - } else { - // Reshape acc_dq from (4, 1, 2) to (4, 2, 1) to write to gdQaccum - Tensor acc_dq_reshaped = make_tensor(acc_dq.data(), - make_layout(get<0>(acc_dq.layout()), - get<2>(acc_dq.layout()), - get<1>(acc_dq.layout()))); - cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, acc_dq_reshaped); - } - - if (Double_buffer && m_block > m_block_min) { - // Double buffer for sQ - const int sQ_offset = m_block % 2 == 0 ? size(sQ) : -size(sQ); - tQsQ.data() = tQsQ.data() + sQ_offset; - tSsQ.data() = tSsQ.data() + sQ_offset; - // Advance gQ - tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride)); - pytorch_flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ); - pytorch_flash::cp_async_fence(); - } - - Tensor dS_reshaped = make_tensor(dS.data(), acc_dp.layout()); - // Convert dS from fp32 to fp16 - Tensor tdSrdS = pytorch_flash::convert_type(dS_reshaped); - // if (cute::thread0()) { print(tPrP); } - Tensor tdSadS = smem_thr_copy_PdS.retile_S(tdSrdS); // ((Atom,AtomNum), MMA_N, MMA_N) - cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS); - __syncthreads(); - - // Layout p_l = tPrP.layout(); - // Tensor tdVrPt = make_tensor(tPrP.data(), make_layout(get<0>(p_l), get<2>(p_l), get<1>(p_l))); - // pytorch_flash::gemm_rs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt); - // Tensor tdKrdSt = make_tensor(tdSrdS.data(), tdVrPt.layout()); - // pytorch_flash::gemm_rs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt); - pytorch_flash::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv, - smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt); - // if (cute::thread0() && n_block == 0 && m_block == 0) { print(tdVrPt); } - // if (cute::thread0()) { print(acc_dv); } - - __syncthreads(); // Need syncthreads since we're writing to the same sdO location - - if (m_block > m_block_min) { - // Advance gdO - tdOgdO.data() = tdOgdO.data() + (-int(kBlockM * params.do_row_stride)); - if (Is_first) { - tdOgO.data() = tdOgO.data() + (-int(kBlockM * params.o_row_stride)); - pytorch_flash::copy(gmem_tiled_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ); - pytorch_flash::copy(gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ); - } else { - pytorch_flash::copy(gmem_tiled_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ); - pytorch_flash::cp_async_fence(); - } - } - - pytorch_flash::gemm(acc_dq, tdQrdS, tdQrKt, tdQsdS, tdQsKt, tiled_mma_dq, - smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt); - // if (cute::thread0()) { print(acc_dq); } - - if (m_block > m_block_min) { - gLSE.data() = gLSE.data() + (-int(kBlockM)); - #pragma unroll - for (int mi = 0; mi < size(lse); ++mi) { lse(mi) = gLSE(get<0>(taccScS_row(mi))); } - gdPsum.data() = gdPsum.data() + (-int(kBlockM)); - } - - if (!Is_last) { - // Reshape acc_dq from (4, 1, 2) to (4, 2, 1) to write to gdQaccum - Tensor acc_dq_reshaped = make_tensor(acc_dq.data(), - make_layout(get<0>(acc_dq.layout()), - get<2>(acc_dq.layout()), - get<1>(acc_dq.layout()))); - if (!Seq_parallel) { - 
cute::copy(gmem_tiled_copy_dQaccum, acc_dq_reshaped, tdQgdQaccum); - } else { - // if (cute::thread0()) { print(acc_dq.layout()); printf("\n"); print(acc_dq_reshaped.layout()); printf("\n"); print(tdQgdQaccum.layout()); printf("\n"); } - CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum)); - #pragma unroll - for (int i = 0; i < size(acc_dq); ++i) { atomicAdd(&tdQgdQaccum(i), acc_dq(i)); } - } - } else { - #pragma unroll - for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; } - // Convert acc_dq from fp32 to fp16 - Tensor rdQ = pytorch_flash::convert_type(acc_dq); - Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N) - cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ); - } - - pytorch_flash::gemm(acc_dk, tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, tiled_mma_dkv, - smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt); - // if (cute::thread0()) { print(acc_dk); } - if (Double_buffer) { // Double buffer for sQ - tdKsQt.data() = tdKsQt.data() + (m_block % 2 == 0 ? size(sQ) : -size(sQ)); - } - if (!Double_buffer && m_block > m_block_min) { - __syncthreads(); - // Advance gQ - tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride)); - pytorch_flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ); - pytorch_flash::cp_async_fence(); - } - - if (Is_first && m_block > m_block_min) { - cute::copy(tdOrdO, tdOsdO); - dot_do_o(tdOrdO, tdOrO, gdPsum, - Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout); - } - - if (Is_last) { - __syncthreads(); - Tensor tdQrdQ = make_tensor(shape(tdQgdQ)); - cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ); - tdQgdQ.data() = tdQgdQ.data() + (-int(kBlockM * params.dq_row_stride)); - Tensor cdQ = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ); - #pragma unroll - for (int m = 0; m < size<1>(tdQgdQ); ++m) { - if (Is_even_MN || get<0>(tdQcdQ(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) { - cute::copy(gmem_tiled_copy_dQ, tdQrdQ(_, m, _), tdQgdQ(_, m, _)); - } - } - } - - } - - // Epilogue - - if (Is_dropout) { - #pragma unroll - for (int i = 0; i < size(acc_dv); ++i) { acc_dv(i) *= params.rp_dropout; } - } - #pragma unroll - for (int i = 0; i < size(acc_dk); ++i) { acc_dk(i) *= params.scale_softmax_rp_dropout; } - - // Convert acc_dv from fp32 to fp16 - Tensor rdK = pytorch_flash::convert_type(acc_dk); - Tensor rdV = pytorch_flash::convert_type(acc_dv); - - Tensor sdK = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K) - Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K) - - // Partition sdV and sdK to match the accumulator partitioning - auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv); - auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx); - Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK); // ((Atom,AtomNum), MMA_N, MMA_N) - Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK); // ((Atom,AtomNum),PIPE_M,PIPE_N) - Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV); // ((Atom,AtomNum), MMA_N, MMA_N) - Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV); // ((Atom,AtomNum),PIPE_M,PIPE_N) - - // We need syncthreads here since we're writing to the same location as sK and sV. 
- // Without syncthreads, some thread might modify the location of sK while another thread - // is reading it for dQ gemm, leading to a race condition. - // If Is_last, there's already a __syncthreads() at the end of the loop. - if (!Is_last) { __syncthreads(); } - - cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK); - cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV); - - const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb) - + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride; - const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb) - + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride; - Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_ptr) + row_offset_dk), - Shape, Int>{}, - make_stride(params.dk_row_stride, _1{})); - Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_ptr) + row_offset_dv), - Shape, Int>{}, - make_stride(params.dv_row_stride, _1{})); - - typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV; - auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx); - Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK); // ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK); - Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV); // ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV); - - __syncthreads(); - Tensor tdKrdK = make_tensor(shape(tdKgdK)); - cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK); - Tensor tdVrdV = make_tensor(shape(tdVgdV)); - cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV); - Tensor cdKV = make_identity_tensor(make_shape(size<0>(sdK), size<1>(sdK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) - Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); - Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKgdK))); - #pragma unroll - for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; } - // Clear_OOB_K must be false since we don't want to write zeros to gmem - pytorch_flash::copy( - gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN - ); - pytorch_flash::copy( - gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN - ); - -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void compute_dq_dk_dv(const Params ¶ms) { - - // The block index for the batch. - const int bidb = blockIdx.x; - // const int bidb = blockIdx.y; - // The block index for the head. - const int bidh = blockIdx.y; - // const int bidh = blockIdx.z; - // The thread index. - const int tidx = threadIdx.x; - - const int n_block_max = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; - if (n_block_max == 1) { - compute_dq_dk_dv_1colblock(params, bidb, bidh, 0); - } else { - // Iterating backward from n_block_max - 1 to 0 might save 1 register - compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block_max - 1); - for (int n_block = n_block_max - 2; n_block > 0; n_block--) { - compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); - } - compute_dq_dk_dv_1colblock(params, bidb, bidh, 0); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params ¶ms) { - - // The block index for the batch. 
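// ---------------------------------------------------------------------------
// Illustrative sketch (not part of this patch): how the grid for this
// seq-k-parallel driver is shaped by run_flash_bwd_seqk_parallel further down.
// blockIdx.y indexes the batch, blockIdx.z the head, and blockIdx.x strides over
// K/V column blocks. In deterministic mode gridDim.x is clamped to roughly
// SM_count / (batch * heads), and each blockIdx.x accumulates into its own
// dq_accum split (offset by dq_accum_split_stride) that convert_dQ later sums
// in a fixed order, instead of relying on atomicAdd ordering. Sizes below are
// made-up examples.
#include <cstdio>

struct GridDims { int x, y, z; };

static GridDims bwd_seqk_grid(int seqlen_k, int kBlockN, int batch, int heads,
                              int sm_count, bool deterministic) {
  auto ceil_div = [](int a, int b) { return (a + b - 1) / b; };
  const int num_n_block = ceil_div(seqlen_k, kBlockN);
  const int grid_x = deterministic ? ceil_div(sm_count, batch * heads) : num_n_block;
  return {grid_x, batch, heads};
}

int main() {
  GridDims fast = bwd_seqk_grid(/*seqlen_k=*/4096, /*kBlockN=*/128,
                                /*batch=*/4, /*heads=*/16,
                                /*sm_count=*/132, /*deterministic=*/false);
  GridDims det  = bwd_seqk_grid(4096, 128, 4, 16, 132, /*deterministic=*/true);
  std::printf("non-deterministic grid: (%d, %d, %d)\n", fast.x, fast.y, fast.z);
  std::printf("deterministic grid:     (%d, %d, %d)\n", det.x, det.y, det.z);
}
// ---------------------------------------------------------------------------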
- const int bidb = blockIdx.y; - // The block index for the head. - const int bidh = blockIdx.z; - - // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer. - for (int n_block = blockIdx.x; n_block < (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; n_block += gridDim.x) { - compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h deleted file mode 100644 index dd3bfa1bad7..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h +++ /dev/null @@ -1,338 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include - -#include -#include -#include -#include - -namespace pytorch_flash { - -// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -#define ARCH_SUPPORTS_FLASH -#endif - -#if defined(ARCH_SUPPORTS_FLASH) && defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 11 && \ - defined(__CUDACC_VER_MINOR__) && __CUDACC_VER_MINOR__ >= 8 -#define KERNEL_PARAM_MODIFIER __grid_constant__ -#else -#define KERNEL_PARAM_MODIFIER -#endif - -// Define a macro for unsupported architecture handling to centralize the error message -#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!"); - -// Use a macro to clean up kernel definitions -#define DEFINE_FLASH_BACKWARD_KERNEL(kernelName, ...) 
\ -template \ -__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params params) - -DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K) { - #if defined(ARCH_SUPPORTS_FLASH) - pytorch_flash::compute_dq_dk_dv(params); - #else - FLASH_UNSUPPORTED_ARCH - #endif -} - -DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K) { - #if defined(ARCH_SUPPORTS_FLASH) - static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false - pytorch_flash::compute_dq_dk_dv_seqk_parallel(params); - #else - FLASH_UNSUPPORTED_ARCH - #endif -} - -template -__global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) { - pytorch_flash::compute_dot_do_o(params); -} - -template -__global__ void flash_bwd_clear_dkvaccum_kernel(const Flash_bwd_params params) { - pytorch_flash::clear_dKVaccum(params); -} - -template -__global__ void flash_bwd_convert_dq_kernel(const Flash_bwd_params params, const int nsplits) { - pytorch_flash::convert_dQ(params, nsplits); -} - -template -__global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) { - pytorch_flash::convert_dKV(params); -} - -template -void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream) { - const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; - dim3 grid_m(num_m_block, params.b, params.h); - const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; - int gridDimx = num_n_block; - if (params.deterministic) { - auto dprops = at::cuda::getCurrentDeviceProperties(); - gridDimx = (dprops->multiProcessorCount + params.b * params.h - 1) / (params.b * params.h); - } - dim3 grid_n(gridDimx, params.b, params.h); - - if (!params.deterministic) { - flash_bwd_dot_do_o_kernel<<>>(params); - } else { - flash_bwd_dot_do_o_kernel<<>>(params); - } - C10_CUDA_KERNEL_LAUNCH_CHECK(); - - // We want to specialize to is_even_MN and not just is_even_M, since in the case where N is not - // a multiple of kBlockN, we'll need to apply mask in the loop. - const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_traits::kBlockM == 0 && params.seqlen_k % Kernel_traits::kBlockN == 0; - const bool is_even_K = params.d == Kernel_traits::kHeadDim; - constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1colblock; - // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv); - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { - EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { - LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] { - ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { - // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. 
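// ---------------------------------------------------------------------------
// Illustrative sketch (not part of this patch): the BOOL_SWITCH-style dispatch
// used throughout this launcher. The real macros live in static_switch.h, which
// is not shown in this diff; this is a simplified stand-in that shows the idea:
// a runtime bool is turned into a compile-time constant so the kernel template
// can be specialized, and every additional switch doubles the number of
// instantiations, which is why the comments here fold some flags together.
#include <cstdio>

#define SIMPLE_BOOL_SWITCH(COND, CONST_NAME, ...)   \
  do {                                              \
    if (COND) {                                     \
      static constexpr bool CONST_NAME = true;      \
      __VA_ARGS__();                                \
    } else {                                        \
      static constexpr bool CONST_NAME = false;     \
      __VA_ARGS__();                                \
    }                                               \
  } while (0)

template <bool Is_causal, bool Is_even_K>
void launch_specialized() {  // stand-in for choosing and launching one kernel template
  std::printf("instantiated with Is_causal=%d Is_even_K=%d\n", Is_causal, Is_even_K);
}

void dispatch(bool is_causal, bool is_even_K) {
  SIMPLE_BOOL_SWITCH(is_causal, kIsCausal, [&] {
    SIMPLE_BOOL_SWITCH(is_even_K, kIsEvenK, [&] {
      launch_specialized<kIsCausal, kIsEvenK>();  // 2 switches -> 4 instantiations
    });
  });
}

int main() { dispatch(true, false); }
// ---------------------------------------------------------------------------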
- // If head dim > 128, set IsEvenMNConst to false to reduce number of templates - // If Is_local, set Is_causal to false - auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; - // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; - if (smem_size_dq_dk_dv >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); - } - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); - }); - }); - }); - }); - - auto kernel_dq = &flash_bwd_convert_dq_kernel; - if (Kernel_traits::kSmemdQSize >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize)); - } - kernel_dq<<>>(params, !params.deterministic ? 1 : gridDimx); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -template -void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { -#ifndef FLASHATTENTION_DISABLE_BACKWARD - run_flash_bwd_seqk_parallel(params, stream); -#endif -} - -template -void run_mha_bwd_hdim32(Flash_bwd_params ¶ms, cudaStream_t stream) { - constexpr static int Headdim = 32; - int device; - cudaGetDevice(&device); - int max_smem_per_block; - cudaError status_ = cudaDeviceGetAttribute( - &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); - if (status_ != cudaSuccess) { - C10_CUDA_CHECK(status_); - } - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) { // 104 KB - if constexpr(!Is_dropout) { // We can afford more registers to keep V in registers - run_flash_bwd, Is_dropout>(params, stream); - } else { - run_flash_bwd, Is_dropout>(params, stream); - } - } else { // 96 KB - run_flash_bwd, Is_dropout>(params, stream); - } - }); -} - -template -void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream) { - constexpr static int Headdim = 64; - int device; - cudaGetDevice(&device); - int max_smem_per_block; - cudaError status_ = cudaDeviceGetAttribute( - &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); - if (status_ != cudaSuccess) { - C10_CUDA_CHECK(status_); - } - // printf("max_smem_per_block = %d\n", max_smem_per_block); - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - // Changing AtomLayoutMdQ from 2 to 4 takes the same time - // run_flash_bwd>(params, stream); - // run_flash_bwd>(params, stream); - // run_flash_bwd>(params, stream); - // run_flash_bwd, Is_dropout>(params, stream); - // This is slightly faster. We want to split M more so we need fewer registers to store LSE. 
- if (max_smem_per_block >= 144 * 1024) { - run_flash_bwd, Is_dropout>(params, stream); - // This has a lot of register spilling - // run_flash_bwd, Is_dropout>(params, stream); - } else { - // if (params.h == params.h_k) { - // run_flash_bwd, Is_dropout>(params, stream); - run_flash_bwd, Is_dropout>(params, stream); - // run_flash_bwd, Is_dropout>(params, stream); - // run_flash_bwd, Is_dropout>(params, stream); - // } else { - // } - } - }); - // run_flash_bwd>(params, stream); - // run_flash_bwd>(params, stream); - // run_flash_bwd>(params, stream); - // run_flash_bwd>(params, stream); - // M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times - // run_flash_bwd>(params, stream); - // run_flash_bwd>(params, stream); - // run_flash_bwd>(params, stream); - - // run_flash_bwd>(params, stream); -} - -template -void run_mha_bwd_hdim96(Flash_bwd_params ¶ms, cudaStream_t stream) { - constexpr static int Headdim = 96; - int device; - cudaGetDevice(&device); - int max_smem_per_block; - cudaError status_ = cudaDeviceGetAttribute( - &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); - if (status_ != cudaSuccess) { - C10_CUDA_CHECK(status_); - } - // printf("max_smem_per_block = %d\n", max_smem_per_block); - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - if (max_smem_per_block >= 116 * 1024) { - if constexpr(!Is_dropout) { // 92KB - run_flash_bwd, Is_dropout>(params, stream); - } else { // 116 KB - // This is faster for dropout since we don't have many registers to spare - run_flash_bwd, Is_dropout>(params, stream); - } - } else { - run_flash_bwd, Is_dropout>(params, stream); - } - }); -} - -template -void run_mha_bwd_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream) { - constexpr static int Headdim = 128; - int device; - cudaGetDevice(&device); - int max_smem_per_block; - cudaError status_ = cudaDeviceGetAttribute( - &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); - if (status_ != cudaSuccess) { - C10_CUDA_CHECK(status_); - } - // printf("max_smem_per_block = %d\n", max_smem_per_block); - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - // run_flash_bwd>(params, stream); - // This is faster, in the case of sequence-parallel bwd (where we need fewer registers). - // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why. 
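// ---------------------------------------------------------------------------
// Illustrative sketch (not part of this patch): the device query these
// run_mha_bwd_hdim* launchers repeat. The opt-in shared-memory limit per block
// decides which kernel-traits configuration fits (the 144 KB / 116 KB / 96 KB
// style branches in this file); a kernel that wants more than 48 KB of dynamic
// smem must also be opted in with cudaFuncSetAttribute(kernel,
// cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) before launch, as the
// seqk-parallel launcher above does. Tile names below are illustrative, and
// error handling is deliberately minimal.
#include <cstdio>
#include <cuda_runtime.h>

int main() {
  int device = 0;
  if (cudaGetDevice(&device) != cudaSuccess) return 1;

  int max_smem_optin = 0;  // bytes of dynamic shared memory a block may opt in to
  if (cudaDeviceGetAttribute(&max_smem_optin,
                             cudaDevAttrMaxSharedMemoryPerBlockOptin,
                             device) != cudaSuccess) return 1;

  std::printf("opt-in smem per block: %d KB\n", max_smem_optin / 1024);

  // Same shape of decision as run_mha_bwd_hdim64 above: take the larger tile
  // when the budget allows it, otherwise fall back to a smaller one.
  if (max_smem_optin >= 144 * 1024) {
    std::printf("-> large (128 x 128) backward tile fits\n");
  } else {
    std::printf("-> fall back to a smaller (64 x 128) tile\n");
  }
  return 0;
}
// ---------------------------------------------------------------------------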
- // run_flash_bwd>(params, stream); - if (max_smem_per_block >= 144 * 1024) { - run_flash_bwd, Is_dropout>(params, stream); - // run_flash_bwd_seqk_parallel, Is_dropout>(params, stream); - // run_flash_bwd_seqk_parallel, Is_dropout>(params, stream); - // run_flash_bwd, Is_dropout>(params, stream); - // run_flash_bwd, Is_dropout>(params, stream); - // run_flash_bwd, Is_dropout>(params, stream); - } else { - // run_flash_bwd, Is_dropout>(params, stream); - run_flash_bwd, Is_dropout>(params, stream); - } - // run_flash_bwd>(params, stream); - - // run_flash_bwd>(params, stream); - }); -} - -template -void run_mha_bwd_hdim160(Flash_bwd_params ¶ms, cudaStream_t stream) { - constexpr static int Headdim = 160; - int device; - cudaGetDevice(&device); - int max_smem_per_block; - cudaError status_ = cudaDeviceGetAttribute( - &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); - if (status_ != cudaSuccess) { - C10_CUDA_CHECK(status_); - } - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - if (max_smem_per_block >= 116 * 1024) { - run_flash_bwd, Is_dropout>(params, stream); - } else { - run_flash_bwd, Is_dropout>(params, stream); - } - }); -} - -template -void run_mha_bwd_hdim192(Flash_bwd_params ¶ms, cudaStream_t stream) { - constexpr static int Headdim = 192; - int device; - cudaGetDevice(&device); - int max_smem_per_block; - cudaError status_ = cudaDeviceGetAttribute( - &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); - if (status_ != cudaSuccess) { - C10_CUDA_CHECK(status_); - } - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - if (max_smem_per_block >= 136 * 1024) { - run_flash_bwd, Is_dropout>(params, stream); - } else { - run_flash_bwd, Is_dropout>(params, stream); - } - }); -} - -template -void run_mha_bwd_hdim224(Flash_bwd_params ¶ms, cudaStream_t stream) { - constexpr static int Headdim = 224; - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - run_flash_bwd, Is_dropout>(params, stream); - }); -} - -template -void run_mha_bwd_hdim256(Flash_bwd_params ¶ms, cudaStream_t stream) { - constexpr static int Headdim = 256; - int device; - cudaGetDevice(&device); - int max_smem_per_block; - cudaError status_ = cudaDeviceGetAttribute( - &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); - if (status_ != cudaSuccess) { - C10_CUDA_CHECK(status_); - } - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - if (max_smem_per_block >= 176 * 1024) { // H100 - run_flash_bwd, Is_dropout>(params, stream); - } else if (max_smem_per_block >= 144 * 1024) { // A100, we don't do double buffering to save smem - run_flash_bwd, Is_dropout>(params, stream); - } else { // sm86 and sm89, max smem is 99 KB. Only works without dropout. V in regs and no double buffering. - if constexpr (!Is_dropout) { - run_flash_bwd, false>(params, stream); - } - } - }); -} - - -}; // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_preprocess_kernel.h b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_preprocess_kernel.h deleted file mode 100644 index 166446e047f..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_preprocess_kernel.h +++ /dev/null @@ -1,377 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024, Tri Dao. 
- ******************************************************************************/ - -#pragma once - -#include - -#include -#include -#include - -#include -#include -#include - -namespace pytorch_flash { - -using namespace cute; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void dot_do_o(Tensor const &do_, Tensor const &o, - Tensor &dP_sum, const int gdP_col_stride, const float scale) { - static_assert(Layout0::rank == 3, "Only support 3D Tensor"); - static_assert(Layout1::rank == 1, "Only support 1D Tensor"); - CUTE_STATIC_ASSERT_V(do_.layout() == o.layout()); - // Reshape do_ and o from (8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, 8 * kHeadDim / 64) - // The last coordinate is the "page". - Tensor do_reshaped = make_tensor(do_.data(), make_layout(get<1>(do_.layout()), - make_layout(get<0>(do_.layout()), - get<2>(do_.layout())))); - Tensor o_reshaped = make_tensor(o.data(), do_reshaped.layout()); - Tensor do_fp32 = pytorch_flash::convert_type(do_reshaped); - Tensor o_fp32 = pytorch_flash::convert_type(o_reshaped); - #pragma unroll - for (int mi = 0; mi < size<0>(do_reshaped); ++mi) { - float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0); - #pragma unroll - for (int ni = 1; ni < size<1>(do_reshaped); ni++) { - dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni); - } - pytorch_flash::SumOp sum_op; - dP_sum_cur = pytorch_flash::Allreduce::run(dP_sum_cur, sum_op) * scale; - if (threadIdx.x % THREADS_PER_ROW == 0) { - dP_sum(mi * gdP_col_stride + threadIdx.x / THREADS_PER_ROW) = dP_sum_cur; - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel. -// This is used in the case where we want to parallelize the backward across seqlen_k. -template -inline __device__ void compute_dot_do_o(const Params ¶ms) { - using Element = typename Kernel_traits::Element; - using ElementAccum = typename Kernel_traits::ElementAccum; - using index_t = typename Kernel_traits::index_t; - - const int m_block = blockIdx.x; - // The block index for the batch. - const int bidb = blockIdx.y; - // The block index for the head. - const int bidh = blockIdx.z; - // The thread index. - const int tidx = threadIdx.x; - - constexpr int kBlockM = Kernel_traits::kBlockM; - constexpr int kHeadDim = Kernel_traits::kHeadDim; - - const BlockInfo binfo(params, bidb); - if (m_block * kBlockM >= binfo.actual_seqlen_q) return; - - const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb) - + m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride; - const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) - + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; - const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb) - + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 
0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded; - const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM; - - Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast(params.do_ptr) + row_offset_do), - Shape, Int>{}, - make_stride(params.do_row_stride, _1{})); - Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), - Shape, Int>{}, - make_stride(params.o_row_stride, _1{})); - Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_accum_ptr) + row_offset_dq_accum), - Shape, Int>{}, - make_stride(params.h * params.d_rounded, _1{})); - Tensor dP_sum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dsoftmax_sum) + row_offset_dpsum), - Shape>{}, Stride<_1>{}); - - typename Kernel_traits::GmemTiledCopydO gmem_tiled_copy_dO; - auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx); - // TODO: careful, we're zeroing out dQaccum with type float4, but when - // we do atomicAdds, we use type float. The layouts are different. Check this. - typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dQaccum; - auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx); - - Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO); - Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO); - Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum); - - Tensor cdO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor tdOcdO = gmem_thr_copy_dO.partition_S(cdO); - - // Allocate predicate tensors for k - Tensor tdOpdO = make_tensor(make_shape(size<2>(tdOgdO))); - // Set predicates for k bounds - #pragma unroll - for (int k = 0; k < size(tdOpdO); ++k) {tdOpdO(k) = get<1>(tdOcdO(0, 0, k)) < params.d;} - - Tensor tdOrdO = make_fragment_like(tdOgdO); - Tensor tdOrO = make_fragment_like(tdOgO); - pytorch_flash::copy( - gmem_tiled_copy_dO, tdOgdO, tdOrdO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM - ); - pytorch_flash::copy( - gmem_tiled_copy_dO, tdOgO, tdOrO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM - ); - // By right we need to scale dP up by 1/p_dropout, but instead we don't and only scale the final - // results (dQ and dK) by 1/p_dropout. So we need to keep dP_sum scaled down by p_dropout here, - // so that (dP - dP_sum) is on the same scale. - dot_do_o(tdOrdO, tdOrO, dP_sum, - Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout); - if (Clear_dQaccum) { - // We're actually not zero'ing out all of dQaccum, but only the part that we're going to - // do atomicAdds on. - Tensor zero = make_fragment_like(tdQgdQaccum); - clear(zero); - cute::copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void clear_dKVaccum(const Params ¶ms) { - using ElementAccum = typename Kernel_traits::ElementAccum; - using index_t = typename Kernel_traits::index_t; - - const int n_block = blockIdx.x; - // The block index for the batch. - const int bidb = blockIdx.y; - // The block index for the head. - const int bidh = blockIdx.z; - // The thread index. 
- const int tidx = threadIdx.x; - - constexpr int kBlockN = Kernel_traits::kBlockN; - constexpr int kHeadDim = Kernel_traits::kHeadDim; - - const BlockInfo binfo(params, bidb); - if (n_block * kBlockN >= binfo.actual_seqlen_k) return; - - const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded + n_block * kBlockN) * params.d_rounded; - - Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_accum_ptr) + row_offset_dkv_accum), - Shape, Int>{}, Stride, _1>{}); - Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_accum_ptr) + row_offset_dkv_accum), - Shape, Int>{}, Stride, _1>{}); - - typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dKVaccum; - auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx); - Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_D(gdKaccum); - Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_D(gdVaccum); - Tensor zero = make_fragment_like(tdKgdKaccum); - clear(zero); - cute::copy(gmem_tiled_copy_dKVaccum, zero, tdKgdKaccum); - cute::copy(gmem_tiled_copy_dKVaccum, zero, tdVgdVaccum); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Convert dQ from dQaccum (in float) to fp16/bf16. -// This is used in the case where we want to parallelize the backward across seqlen_k. -template -inline __device__ void convert_dQ(const Params ¶ms, const int nsplits) { - using Element = typename Kernel_traits::Element; - using ElementAccum = typename Kernel_traits::ElementAccum; - using index_t = typename Kernel_traits::index_t; - - // Shared memory. - extern __shared__ char smem_[]; - - const int m_block = blockIdx.x; - // The block index for the batch. - const int bidb = blockIdx.y; - // The block index for the head. - const int bidh = blockIdx.z; - // The thread index. - const int tidx = threadIdx.x; - - constexpr int kBlockM = Kernel_traits::kBlockM; - constexpr int kHeadDim = Kernel_traits::kHeadDim; - - const BlockInfo binfo(params, bidb); - if (m_block * kBlockM >= binfo.actual_seqlen_q) return; - - const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb) - + m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride; - const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb) - + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 
0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded; - - Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_ptr) + row_offset_dq), - Shape, Int>{}, - make_stride(params.dq_row_stride, _1{})); - Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_accum_ptr) + row_offset_dq_accum), - Shape, Int>{}, - make_stride(params.h * params.d_rounded, _1{})); - - Tensor sdQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), - typename Kernel_traits::SmemLayoutdQ{}); - - typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ; - auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx); - typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dQaccum; - auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx); - - typename Kernel_traits::TiledMmadQ tiled_mma_dq; - auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq); - auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx); - Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N) - - Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ); - Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum); - - Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape, Int>{}); // MMA, MMA_N, MMA_K - CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum)); - - Tensor tdQrdQaccum = make_fragment_like(tdQgdQaccum); - clear(acc_dq); - for (int s = 0; s < nsplits; ++s) { - cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, tdQrdQaccum); - #pragma unroll - for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) += tdQrdQaccum(i); } - tdQgdQaccum.data() = tdQgdQaccum.data() + params.dq_accum_split_stride; - } - #pragma unroll - for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; } - // Convert acc_dq from fp32 to fp16 - Tensor rdQ = pytorch_flash::convert_type(acc_dq); - Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N) - cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ); - __syncthreads(); - Tensor tdQrdQ = make_tensor(shape(tdQgdQ)); - cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ); - - Tensor cdQ = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ); - Tensor tdQpdQ = make_tensor(make_shape(size<2>(tdQgdQ))); - #pragma unroll - for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(0, 0, k)) < params.d; } - // Clear_OOB_K must be false since we don't want to write zeros to gmem - pytorch_flash::copy( - gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM - ); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Convert dK and dV from dKaccum and dVaccum (in float) to fp16/bf16. -// This is used in the case where we want to parallelize the backward across seqlen_q. -template -inline __device__ void convert_dKV(const Params ¶ms) { - using Element = typename Kernel_traits::Element; - using ElementAccum = typename Kernel_traits::ElementAccum; - using index_t = typename Kernel_traits::index_t; - - // Shared memory. - extern __shared__ char smem_[]; - - const int n_block = blockIdx.x; - // The block index for the batch. - const int bidb = blockIdx.y; - // The block index for the head. - const int bidh = blockIdx.z; - // The thread index. 
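
`convert_dQ` above sums the fp32 dQaccum slices produced by the `nsplits` seqlen_k-parallel backward passes, rescales by `scale_softmax_rp_dropout`, and only then narrows to fp16/bf16. A schematic C++ sketch of that reduction, with a hypothetical `split_stride` standing in for `params.dq_accum_split_stride` and plain arrays in place of the tiled gmem/smem copies:

```
#include <cstddef>
#include <vector>

// Schematic of convert_dQ's accumulation:
//   dq[i] = (sum over splits of dq_accum[s][i]) * scale
// where `scale` corresponds to params.scale_softmax_rp_dropout above and
// `split_stride` mirrors params.dq_accum_split_stride (elements between splits).
void convert_dq_reference(const std::vector<float>& dq_accum,
                          std::size_t nsplits,
                          std::size_t split_stride,
                          std::size_t n_elems,
                          float scale,
                          std::vector<float>& dq_out) {
    dq_out.assign(n_elems, 0.0f);
    for (std::size_t s = 0; s < nsplits; ++s) {
        const float* slice = dq_accum.data() + s * split_stride;
        for (std::size_t i = 0; i < n_elems; ++i) {
            dq_out[i] += slice[i];
        }
    }
    for (std::size_t i = 0; i < n_elems; ++i) {
        dq_out[i] *= scale;  // the real kernel also narrows fp32 -> fp16/bf16 here
    }
}
```
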
- const int tidx = threadIdx.x; - - constexpr int kBlockN = Kernel_traits::kBlockN; - constexpr int kHeadDim = Kernel_traits::kHeadDim; - - const BlockInfo binfo(params, bidb); - if (n_block * kBlockN >= binfo.actual_seqlen_k) return; - - const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb) - + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride; - const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb) - + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride; - const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded - + n_block * kBlockN) * params.d_rounded; - - Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_ptr) + row_offset_dk), - Shape, Int>{}, - make_stride(params.dk_row_stride, _1{})); - Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_ptr) + row_offset_dv), - Shape, Int>{}, - make_stride(params.dv_row_stride, _1{})); - Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_accum_ptr) + row_offset_dkv_accum), - Shape, Int>{}, - Stride, _1>{}); - Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_accum_ptr) + row_offset_dkv_accum), - Shape, Int>{}, - Stride, _1>{}); - - Tensor sdK = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), - typename Kernel_traits::SmemLayoutdKV{}); - Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K) - - typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dKV; - auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx); - typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dKVaccum; - auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx); - - typename Kernel_traits::TiledMmadKV tiled_mma_dkv; - auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv); - auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx); - Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK); // ((Atom,AtomNum),PIPE_M,PIPE_N) - Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV); // ((Atom,AtomNum),PIPE_M,PIPE_N) - - Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK); // ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK); - Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV); // ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV); - Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_S(gdKaccum); - Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_S(gdVaccum); - - Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K - Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K - CUTE_STATIC_ASSERT_V(size(acc_dk) == size(tdKgdKaccum)); - CUTE_STATIC_ASSERT_V(size(acc_dv) == size(tdVgdVaccum)); - - Tensor tdKrdKaccum = make_fragment_like(tdKgdKaccum); - Tensor tdVrdVaccum = make_fragment_like(tdVgdVaccum); - cute::copy(gmem_tiled_copy_dKVaccum, tdKgdKaccum, tdKrdKaccum); - cute::copy(gmem_tiled_copy_dKVaccum, tdVgdVaccum, tdVrdVaccum); - #pragma unroll - for (int i = 0; i < size(acc_dk); ++i) { - acc_dk(i) = tdKrdKaccum(i) * params.scale_softmax_rp_dropout; - } - #pragma unroll - for (int i = 0; i < size(acc_dv); ++i) { - acc_dv(i) = tdVrdVaccum(i) * params.rp_dropout; - } - // Convert acc_dk from fp32 to fp16 - Tensor rdK = 
pytorch_flash::convert_type(acc_dk); - Tensor rdV = pytorch_flash::convert_type(acc_dv); - Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK); // ((Atom,AtomNum), MMA_N, MMA_N) - Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV); // ((Atom,AtomNum), MMA_N, MMA_N) - cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK); - cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV); - __syncthreads(); - Tensor tdKrdK = make_tensor(shape(tdKgdK)); - Tensor tdVrdV = make_tensor(shape(tdVgdV)); - cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK); - cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV); - - Tensor cdKV = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); - Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKgdK))); - #pragma unroll - for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; } - // Clear_OOB_K must be false since we don't want to write zeros to gmem - pytorch_flash::copy( - gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN - ); - pytorch_flash::copy( - gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN - ); -} - -} // namespace flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_kernel.h b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_kernel.h deleted file mode 100644 index 3a9321d913a..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_kernel.h +++ /dev/null @@ -1,1254 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include - -#include -#include -#include - - -#include -#include -#include -#include -#include -#include -#include - -namespace pytorch_flash { - -using namespace cute; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) { - - using Element = typename Kernel_traits::Element; - using ElementAccum = typename Kernel_traits::ElementAccum; - using index_t = typename Kernel_traits::index_t; - - // Shared memory. - extern __shared__ char smem_[]; - - // The thread index. - const int tidx = threadIdx.x; - - constexpr int kBlockM = Kernel_traits::kBlockM; - constexpr int kBlockN = Kernel_traits::kBlockN; - constexpr int kHeadDim = Kernel_traits::kHeadDim; - constexpr int kNWarps = Kernel_traits::kNWarps; - - auto [seed, offset] = at::cuda::philox::unpack(params.philox_args); - pytorch_flash::Dropout dropout(seed, offset, params.p_dropout_in_uint8_t, - bidb, bidh, tidx, params.h); - - // Save seed and offset for backward. If we don't have this here, the 0-th thread block might - // exit early and no one saves the rng state. - if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) { - if (params.philox_args.captured_) { - *params.seed = seed; - *params.extragraph_offset = offset; - } - } - - const BlockInfo binfo(params, bidb); - if (m_block * kBlockM >= binfo.actual_seqlen_q) return; - - const int n_block_min = !Is_local ? 
0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); - int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); - if (Is_causal || Is_local) { - n_block_max = std::min(n_block_max, - cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); - // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { - // printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max); - // } - } - // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0. - // Otherwise we might read OOB elements from gK and gV. - if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) { - Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) - + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)), - make_shape(binfo.actual_seqlen_q, params.h, params.d), - make_stride(params.o_row_stride, params.o_head_stride, _1{})); - Tensor gO = local_tile(mO(_, bidh, _), Shape, Int>{}, - make_coord(m_block, 0)); // (kBlockM, kHeadDim) - Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr)), - make_shape(params.b, params.h, params.seqlen_q), - make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{})); - Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape>{}, make_coord(m_block)); - - - typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; - auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); - Tensor tOgO = gmem_thr_copy_O.partition_D(gO); - Tensor tOrO = make_tensor(shape(tOgO)); - clear(tOrO); - // Construct identity layout for sO - Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - // Repeat the partitioning with identity layouts - Tensor tOcO = gmem_thr_copy_O.partition_D(cO); - Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); - if (!Is_even_K) { - #pragma unroll - for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } - } - // Clear_OOB_K must be false since we don't want to write zeros to gmem - pytorch_flash::copy( - gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM - ); - #pragma unroll - for (int m = 0; m < size<1>(tOgO); ++m) { - const int row = get<0>(tOcO(0, m, 0)); - if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; } - } - return; - } - // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); } - - // We iterate over the blocks in reverse order. This is because the last block is the only one - // that needs masking when we read K and V from global memory. Moreover, iterating in reverse - // might save us 1 register (we just need n_block instead of both n_block and n_block_max). 
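
The block-range arithmetic above, together with the early-exit path that zero-fills gO and writes INFINITY to gLSE, is easier to read in isolation. A small C++ sketch of how [n_block_min, n_block_max) is derived for causal and sliding-window ("local") attention; pass `window_right = 0` to model the causal case, and a large window to disable a side:

```
#include <algorithm>

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

struct BlockRange { int n_block_min; int n_block_max; };

// Mirrors the [n_block_min, n_block_max) computation above for one query tile
// m_block. The causal diagonal is aligned to the bottom-right, hence the
// seqlen_k - seqlen_q shift.
BlockRange key_block_range(int m_block, int kBlockM, int kBlockN,
                           int seqlen_q, int seqlen_k,
                           bool is_causal, bool is_local,
                           int window_left, int window_right) {
    const int shift = seqlen_k - seqlen_q;
    const int n_block_min = !is_local
        ? 0
        : std::max(0, (m_block * kBlockM + shift - window_left) / kBlockN);
    int n_block_max = ceil_div(seqlen_k, kBlockN);
    if (is_causal || is_local) {
        n_block_max = std::min(n_block_max,
            ceil_div((m_block + 1) * kBlockM + shift + window_right, kBlockN));
    }
    // If n_block_max <= n_block_min the kernel above writes zeros to O and
    // INFINITY to LSE and returns early, which also covers seqlen_k == 0.
    return {n_block_min, n_block_max};
}
```
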
- - const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded - + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; - - Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) - + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)), - make_shape(binfo.actual_seqlen_q, params.h, params.d), - make_stride(params.q_row_stride, params.q_head_stride, _1{})); - Tensor gQ = local_tile(mQ(_, bidh, _), Shape, Int>{}, - make_coord(m_block, 0)); // (kBlockM, kHeadDim) - Tensor mK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) - + binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)), - make_shape(binfo.actual_seqlen_k, params.h_k, params.d), - make_stride(params.k_row_stride, params.k_head_stride, _1{})); - Tensor gK = local_tile(mK(_, bidh / params.h_h_k_ratio, _), Shape, Int>{}, - make_coord(_, 0)); // (kBlockN, kHeadDim, nblocksN) - Tensor mV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) - + binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)), - make_shape(binfo.actual_seqlen_k, params.h_k, params.d), - make_stride(params.v_row_stride, params.v_head_stride, _1{})); - Tensor gV = local_tile(mV(_, bidh / params.h_h_k_ratio, _), Shape, Int>{}, - make_coord(_, 0)); // (kBlockN, kHeadDim, nblocksN) - Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast(params.p_ptr) + row_offset_p), - Shape, Int>{}, - make_stride(params.seqlen_k_rounded, _1{})); - - Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), - typename Kernel_traits::SmemLayoutQ{}); - // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem; - Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)), - typename Kernel_traits::SmemLayoutKV{}); - Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); - Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); - Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); - - typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; - auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); - - Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); - Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); - Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K, nblocksN) - Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); - Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN) - Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); - - typename Kernel_traits::TiledMma tiled_mma; - auto thr_mma = tiled_mma.get_thread_slice(tidx); - Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) - Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) - Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) - - Tensor tSgS = thr_mma.partition_C(gP); - - Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K - - // - // Copy Atom retiling - // - - auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); - auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); - // if (cute::thread0()) {smem_thr_copy_Q.print_all();} - Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); - // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");} - - auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); - auto smem_thr_copy_K = 
smem_tiled_copy_K.get_thread_slice(tidx); - Tensor tSsK = smem_thr_copy_K.partition_S(sK); - - auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); - auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); - Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); - - // - // PREDICATES - // - - // // Allocate predicate tensors for m and n - // Tensor tQpQ = make_tensor(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{}); - // Tensor tKVpKV = make_tensor(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{}); - - // Construct identity layout for sQ and sK - Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) - // Tensor tScQ = thr_mma.partition_A(cQ); // (MMA,MMA_M,MMA_K) - // if (cute::thread0()) { - // print(tScQ.layout()); printf("\n"); - // for (int i = 0; i < size(tScQ); ++i) { - // printf("%d ", get<0>(tScQ(i))); - // } - // printf("\n"); - // for (int i = 0; i < size(tScQ); ++i) { - // printf("%d ", get<1>(tScQ(i))); - // } - // printf("\n"); - // } - - // Repeat the partitioning with identity layouts - Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) - - // Allocate predicate tensors for k - Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); - Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); - - // Set predicates for k bounds - if (!Is_even_K) { - #pragma unroll - for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } - #pragma unroll - for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } - } - - // Prologue - - // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs - pytorch_flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, - binfo.actual_seqlen_q - m_block * kBlockM); - if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); } - - // // if (cute::thread(1, 0)) { print(tQsQ); } - // // Tensor sQNoSwizzle = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutQNoSwizzle{}); - // // if (cute::thread0()) { print(sQNoSwizzle); } - - if (Kernel_traits::Share_Q_K_smem) { - pytorch_flash::cp_async_wait<0>(); - __syncthreads(); - Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); - CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M - cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); - __syncthreads(); - } - - int n_block = n_block_max - 1; - // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. - pytorch_flash::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV, - binfo.actual_seqlen_k - n_block * kBlockN); - cute::cp_async_fence(); - // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); } - // __syncthreads(); - - if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) { - pytorch_flash::cp_async_wait<1>(); - __syncthreads(); - Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); - CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M - cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); - } - - clear(acc_o); - - pytorch_flash::Softmax<2 * size<1>(acc_o)> softmax; - - const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 
0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; - pytorch_flash::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope); - - // For performance reason, we separate out two kinds of iterations: - // those that need masking on S, and those that don't. - // We need masking on S for the very last block when K and V has length not multiple of kBlockN. - // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. - // We will have at least 1 "masking" iteration. - - // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to - // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. - constexpr int n_masking_steps = (!Is_causal && !Is_local) - ? 1 - : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); - #pragma unroll - for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { - Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) - clear(acc_s); - pytorch_flash::cp_async_wait<0>(); - __syncthreads(); - - // Advance gV - if (masking_step > 0) { - pytorch_flash::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV); - } else { - // Clear the smem tiles to account for predicated off loads - pytorch_flash::copy( - gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN - ); - } - cute::cp_async_fence(); - cute::cp_async_fence(); - - pytorch_flash::gemm( - acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, - smem_thr_copy_Q, smem_thr_copy_K - ); - // if (cute::thread0()) { print(acc_s); } - - mask.template apply_mask( - acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 - ); - - pytorch_flash::cp_async_wait<0>(); - __syncthreads(); - if (n_block > n_block_min) { - pytorch_flash::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); - // This cp_async_fence needs to be in the if block, otherwise the synchronization - // isn't right and we get race conditions. - cute::cp_async_fence(); - } - - // TODO: when we have key_padding_mask we'll need to Check_inf - masking_step == 0 - ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2) - : softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); - - // Convert acc_s from fp32 to fp16/bf16 - Tensor rP = pytorch_flash::convert_type(acc_s); - int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; - int block_col_idx = n_block * (kBlockN / 32); - if (Return_softmax) { - Tensor rP_drop = make_fragment_like(rP); - cute::copy(rP, rP_drop); - dropout.template apply_dropout( - rP_drop, block_row_idx, block_col_idx, kNWarps - ); - cute::copy(rP_drop, tSgS); - tSgS.data() = tSgS.data() + (-kBlockN); - } - if (Is_dropout) { - dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps); - } - - // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) - // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. 
- Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout())); - // if (cute::thread0()) { print(tOrP); } - pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); - // if (cute::thread0()) { print(scores); } - - // This check is at the end of the loop since we always have at least 1 iteration - if (n_masking_steps > 1 && n_block <= n_block_min) { - --n_block; - break; - } - } - - // These are the iterations where we don't need masking on S - for (; n_block >= n_block_min; --n_block) { - Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) - clear(acc_s); - pytorch_flash::cp_async_wait<0>(); - __syncthreads(); - pytorch_flash::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV); - cute::cp_async_fence(); - - pytorch_flash::gemm( - acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, - smem_thr_copy_Q, smem_thr_copy_K - ); - - pytorch_flash::cp_async_wait<0>(); - __syncthreads(); - if (n_block > n_block_min) { - pytorch_flash::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); - // This cp_async_fence needs to be in the if block, otherwise the synchronization - // isn't right and we get race conditions. - cute::cp_async_fence(); - } - - mask.template apply_mask( - acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 - ); - - softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); - - Tensor rP = pytorch_flash::convert_type(acc_s); - int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; - int block_col_idx = n_block * (kBlockN / 32); - if (Return_softmax) { - Tensor rP_drop = make_fragment_like(rP); - cute::copy(rP, rP_drop); - dropout.template apply_dropout( - rP_drop, block_row_idx, block_col_idx, kNWarps - ); - cute::copy(rP_drop, tSgS); - tSgS.data() = tSgS.data() + (-kBlockN); - } - if (Is_dropout) { - dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps); - } - - // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) - // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. - Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout())); - pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); - } - - // Epilogue - - Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax, params.rp_dropout); - - // Convert acc_o from fp32 to fp16/bf16 - Tensor rO = pytorch_flash::convert_type(acc_o); - Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) - // Partition sO to match the accumulator partitioning - auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); - auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); - Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) - Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) - - // sO has the same size as sQ, so we don't need to sync here. 
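
The `softmax_rescale_o` calls in both loops above implement the streaming ("online") softmax: each new block of scores updates a running row max, rescales the partial output accumulator and the running row sum, and the epilogue's `normalize_softmax_lse` then needs only one final division. A single-row, scalar C++ illustration of that recurrence (head_dim collapsed to 1, no tiling, purely a sketch):

```
#include <cmath>
#include <cstddef>
#include <vector>

// One-row illustration of the online-softmax recurrence used above:
// maintain (running_max, running_sum, acc_o) while streaming blocks of
// scores s and values v; finish with acc_o / running_sum and
// lse = running_max + log(running_sum).
struct OnlineSoftmaxRow {
    float running_max = -INFINITY;
    float running_sum = 0.0f;
    float acc_o = 0.0f;  // head_dim == 1 for brevity

    void consume_block(const std::vector<float>& s, const std::vector<float>& v) {
        float block_max = running_max;
        for (float x : s) block_max = std::fmax(block_max, x);
        const float rescale = std::exp(running_max - block_max);
        acc_o *= rescale;        // rescale previously accumulated output
        running_sum *= rescale;  // and the previously accumulated row sum
        for (std::size_t i = 0; i < s.size(); ++i) {
            const float p = std::exp(s[i] - block_max);
            running_sum += p;
            acc_o += p * v[i];
        }
        running_max = block_max;
    }

    float output() const { return acc_o / running_sum; }
    float lse() const { return running_max + std::log(running_sum); }
};
```
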
- if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); } - - cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); - - Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) - + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)), - make_shape(binfo.actual_seqlen_q, params.h, params.d), - make_stride(params.o_row_stride, params.o_head_stride, _1{})); - Tensor gO = local_tile(mO(_, bidh, _), Shape, Int>{}, - make_coord(m_block, 0)); // (kBlockM, kHeadDim) - Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr)), - make_shape(params.b, params.h, params.seqlen_q), - make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{})); - Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape>{}, make_coord(m_block)); - - typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; - auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); - Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tOgO = gmem_thr_copy_O.partition_D(gO); - - __syncthreads(); - - Tensor tOrO = make_tensor(shape(tOgO)); - cute::copy(gmem_tiled_copy_O, tOsO, tOrO); - - Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) - static_assert(decltype(size<0>(taccOcO))::value == 4); - // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. - Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); - CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M - if (get<1>(taccOcO_row(0)) == 0) { - #pragma unroll - for (int mi = 0; mi < size(lse); ++mi) { - const int row = get<0>(taccOcO_row(mi)); - if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); } - } - } - - // Construct identity layout for sO - Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - // Repeat the partitioning with identity layouts - Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); - if (!Is_even_K) { - #pragma unroll - for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } - } - // Clear_OOB_K must be false since we don't want to write zeros to gmem - pytorch_flash::copy( - gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM - ); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) { - - using Element = typename Kernel_traits::Element; - using ElementAccum = typename Kernel_traits::ElementAccum; - using index_t = typename Kernel_traits::index_t; - - // Shared memory. - extern __shared__ char smem_[]; - - // The thread index. 
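
`compute_attn_1rowblock_splitkv` below carves the key blocks of one query tile into `num_n_splits` slices; each split accumulates a partial output (gOaccum) and a partial log-sum-exp (gLSEaccum) in fp32 that a later combine pass merges. A compact C++ sketch of the per-split slice, mirroring the `n_blocks_per_split` arithmetic that follows (the causal/local clipping is applied separately, as above):

```
#include <algorithm>

constexpr int ceil_div_i(int a, int b) { return (a + b - 1) / b; }

struct SplitRange { int begin; int end; };

// Split s processes key blocks [s * per_split, min((s + 1) * per_split, total)),
// matching the n_blocks_per_split / n_block_min / n_block_max logic below.
// Empty slices take the early-exit path (zero Oaccum, -INFINITY LSEaccum).
SplitRange split_key_blocks(int seqlen_k, int kBlockN,
                            int n_split_idx, int num_n_splits) {
    const int total_blocks = ceil_div_i(seqlen_k, kBlockN);
    const int per_split = ceil_div_i(total_blocks, num_n_splits);
    const int begin = n_split_idx * per_split;
    const int end = std::min(total_blocks, (n_split_idx + 1) * per_split);
    return {begin, end};
}
```
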
- const int tidx = threadIdx.x; - - constexpr int kBlockM = Kernel_traits::kBlockM; - constexpr int kBlockN = Kernel_traits::kBlockN; - constexpr int kHeadDim = Kernel_traits::kHeadDim; - constexpr int kNWarps = Kernel_traits::kNWarps; - - using GmemTiledCopyO = std::conditional_t< - !Split, - typename Kernel_traits::GmemTiledCopyO, - typename Kernel_traits::GmemTiledCopyOaccum - >; - using ElementO = std::conditional_t; - - const BlockInfo binfo(params, bidb); - // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); } - // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); } - if (m_block * kBlockM >= binfo.actual_seqlen_q) return; - - const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits; - const int n_block_min = !Is_local - ? n_split_idx * n_blocks_per_split - : std::max(n_split_idx * n_blocks_per_split, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); - int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split); - if (Is_causal || Is_local) { - n_block_max = std::min(n_block_max, - cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); - } - if (n_block_min >= n_block_max) { // This also covers the case where n_block_max <= 0 - // We exit early and write 0 to gOaccum and -inf to gLSEaccum. - // Otherwise we might read OOB elements from gK and gV, - // or get wrong results when we combine gOaccum from different blocks. - const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) - + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; - const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q - + m_block * kBlockM) * params.d_rounded; - const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; - Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), - Shape, Int>{}, - make_stride(Split ? kHeadDim : params.o_row_stride, _1{})); - Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? 
params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum), - Shape>{}, Stride<_1>{}); - - GmemTiledCopyO gmem_tiled_copy_Oaccum; - auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); - Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); - Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); - clear(tOrOaccum); - // Construct identity layout for sO - Tensor cO = make_identity_tensor(make_shape(size<0>(gOaccum), size<1>(gOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - // Repeat the partitioning with identity layouts - Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); - Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); - if (!Is_even_K) { - #pragma unroll - for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } - } - // Clear_OOB_K must be false since we don't want to write zeros to gmem - pytorch_flash::copy( - gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM - ); - #pragma unroll - for (int m = 0; m < size<1>(tOgOaccum); ++m) { - const int row = get<0>(tOcO(0, m, 0)); - if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSEaccum(row) = Split ? -INFINITY : INFINITY; } - } - return; - } - - // We iterate over the blocks in reverse order. This is because the last block is the only one - // that needs masking when we read K and V from global memory. Moreover, iterating in reverse - // might save us 1 register (we just need n_block instead of both n_block and n_block_max). - - - // We move K and V to the last block. - const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb]; - const int *block_table = params.block_table == nullptr ? nullptr : params.block_table + bidb * params.block_table_batch_stride; - const int block_table_idx = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN / params.page_block_size; - const int block_table_offset = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN - block_table_idx * params.page_block_size; - const index_t row_offset_k = block_table == nullptr - ? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) - + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride - : block_table[block_table_idx] * params.k_batch_stride + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; - const index_t row_offset_v = block_table == nullptr - ? 
binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) - + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride - : block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; - - Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)), - make_shape(binfo.actual_seqlen_q, params.h, params.d), - make_stride(params.q_row_stride, params.q_head_stride, _1{})); - Tensor gQ = local_tile(mQ(_, bidh, _), Shape, Int>{}, - make_coord(m_block, 0)); // (kBlockM, kHeadDim) - Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), - Shape, Int>{}, - make_stride(params.k_row_stride, _1{})); - // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); } - Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), - Shape, Int>{}, - make_stride(params.v_row_stride, _1{})); - - Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), - typename Kernel_traits::SmemLayoutQ{}); - Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{}); - Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); - Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); - Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); - - typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; - auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); - - Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); - Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); - Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) - Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); - Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) - Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); - - typename Kernel_traits::TiledMma tiled_mma; - auto thr_mma = tiled_mma.get_thread_slice(tidx); - Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) - Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) - Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) - - Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K - - // - // Copy Atom retiling - // - - auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); - auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); - Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); - - auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); - auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); - Tensor tSsK = smem_thr_copy_K.partition_S(sK); - - auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); - auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); - Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); - - // PREDICATES - // - - // // Allocate predicate tensors for m and n - // Tensor tQpQ = make_tensor(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{}); - // Tensor tKVpKV = make_tensor(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{}); - - // Construct identity layout for sQ and sK - 
Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) - - // Repeat the partitioning with identity layouts - Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) - - // Allocate predicate tensors for k - Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); - Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); - - // Set predicates for k bounds - if (!Is_even_K) { - #pragma unroll - for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } - #pragma unroll - for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } - } - - // Prologue - - // Copy from Knew to K, optionally apply rotary embedding. - typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary; - auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx); - typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont; - auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx); - if constexpr (Append_KV) { - // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to - // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe. - // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache. - const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2); - Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), - Shape, Int>{}, - make_stride(params.rotary_dim / 2, _1{})); - Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), - Shape, Int>{}, - make_stride(params.rotary_dim / 2, _1{})); - Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), - Shape, Int>{}, - make_stride(params.rotary_dim / 2, _1{})); - Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), - Shape, Int>{}, - make_stride(params.rotary_dim / 2, _1{})); - Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); - Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); - Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); - Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); - // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); } - // if (cute::thread(8, 0)) { print_tensor(gCos); } - // if (cute::thread(0, 0)) { print_tensor(tRgCos); } - - const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) - + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride; - const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) - + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride; - // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them, - // e.g. 
if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64]. - // This maps to accessing the first 64 rows of knew_ptr. - Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.knew_ptr) - + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride), - Shape, Int>{}, - make_stride(params.knew_row_stride, _1{})); - // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); } - Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.vnew_ptr) - + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride), - Shape, Int>{}, - make_stride(params.vnew_row_stride, _1{})); - Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) - Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) - - const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN); - auto tKgK_data = tKgK.data(); - auto tVgV_data = tVgV.data(); - for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) { - pytorch_flash::copy_w_min_idx( - tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN - ); - tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); - if (params.rotary_dim == 0) { - pytorch_flash::copy_w_min_idx( - tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN - ); - } else { - if (params.is_rotary_interleaved) { - // Don't clear OOB_K because we're writing to global memory - pytorch_flash::copy_rotary_interleaved( - tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, - binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim - ); - tRgCos.data() = tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2)); - tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2)); - } else { - // Don't clear OOB_K because we're writing to global memory - pytorch_flash::copy_rotary_contiguous( - tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, - binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim - ); - tRgCosCont.data() = tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2)); - tRgSinCont.data() = tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 2)); - - } - } - tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); - if (block_table == nullptr) { - tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - } else { - if (n_block > n_block_copy_min) { - const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; - const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; - const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; - const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; - const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur]; - const int offset_diff = block_table_offset_next - block_table_offset_cur; - tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride; - tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride; - } - } - } - // 
Need this before we can read in K again, so that we'll see the updated K values. - __syncthreads(); - tKgK.data() = tKgK_data; - tVgV.data() = tVgV_data; - } - - // Read Q from gmem to smem, optionally apply rotary embedding. - if (!Append_KV || params.rotary_dim == 0) { - // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs - pytorch_flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, - binfo.actual_seqlen_q - m_block * kBlockM); - } else { - const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2); - // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache. - // We do this by setting the row stride of gCos / gSin to 0. - Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), - Shape, Int>{}, - make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); - Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), - Shape, Int>{}, - make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); - Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), - Shape, Int>{}, - make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); - Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), - Shape, Int>{}, - make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); - Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); - Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); - Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); - Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); - if (params.is_rotary_interleaved) { - pytorch_flash::copy_rotary_interleaved( - tQgQ, tQsQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, - 0, params.d, params.rotary_dim - ); - } else { - pytorch_flash::copy_rotary_contiguous( - tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, - 0, params.d, params.rotary_dim - ); - } - } - - int n_block = n_block_max - 1; - // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. - pytorch_flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, - binfo.actual_seqlen_k - n_block * kBlockN); - cute::cp_async_fence(); - - // pytorch_flash::cp_async_wait<0>(); - // __syncthreads(); - // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); } - // __syncthreads(); - - clear(acc_o); - - pytorch_flash::Softmax<2 * size<1>(acc_o)> softmax; - - const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; - pytorch_flash::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope); - - // For performance reason, we separate out two kinds of iterations: - // those that need masking on S, and those that don't. - // We need masking on S for the very last block when K and V has length not multiple of kBlockN. - // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. - // We will have at least 1 "masking" iteration. - - // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to - // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. 
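
The comment above about masking 1 versus 2 (or more) trailing blocks deserves a concrete example. A tiny constexpr C++ helper reproducing the `n_masking_steps` rule used just below (and in the non-split kernel earlier), with a few spelled-out cases:

```
constexpr int ceil_div_c(int a, int b) { return (a + b - 1) / b; }

// Number of trailing key blocks that need masking, as computed below:
//   - neither causal nor local: only the ragged tail            -> 1
//   - causal with even seqlens: the diagonal crosses ceil(kBlockM / kBlockN) blocks
//   - otherwise one extra block, because seqlen_k may end mid-block
constexpr int n_masking_steps(bool is_causal, bool is_local, bool is_even_mn,
                              int kBlockM, int kBlockN) {
    return (!is_causal && !is_local)
        ? 1
        : ((is_even_mn && is_causal) ? ceil_div_c(kBlockM, kBlockN)
                                     : ceil_div_c(kBlockM, kBlockN) + 1);
}

// kBlockM == kBlockN == 128, causal, even lengths      -> 1 masking step
// kBlockM == 128, kBlockN == 64, causal, even lengths  -> 2 masking steps
// same tile sizes, seqlen_k not a multiple of kBlockN  -> 3 masking steps
static_assert(n_masking_steps(true, false, true, 128, 128) == 1, "");
static_assert(n_masking_steps(true, false, true, 128, 64) == 2, "");
static_assert(n_masking_steps(true, false, false, 128, 64) == 3, "");
```
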
- constexpr int n_masking_steps = (!Is_causal && !Is_local) - ? 1 - : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); - #pragma unroll - for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { - Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) - clear(acc_s); - pytorch_flash::cp_async_wait<0>(); - __syncthreads(); - - // Advance gV - if (masking_step > 0) { - if (block_table == nullptr) { - tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - } else { - const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size; - const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size; - const int block_table_idx_next = n_block * kBlockN / params.page_block_size; - const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; - tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; - } - pytorch_flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); - } else { - // Clear the smem tiles to account for predicated off loads - pytorch_flash::copy( - gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN - ); - } - cute::cp_async_fence(); - - pytorch_flash::gemm( - acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, - smem_thr_copy_Q, smem_thr_copy_K - ); - // if (cute::thread0()) { print(acc_s); } - - mask.template apply_mask( - acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 - ); - - pytorch_flash::cp_async_wait<0>(); - __syncthreads(); - // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); } - // __syncthreads(); - - if (n_block > n_block_min) { - // Advance gK - if (block_table == nullptr) { - tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - } else { - const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; - const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; - const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; - const int block_table_offset_next =(n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; - tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; - } - pytorch_flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); - // This cp_async_fence needs to be in the if block, otherwise the synchronization - // isn't right and we get race conditions. - cute::cp_async_fence(); - } - - // We have key_padding_mask so we'll need to Check_inf - masking_step == 0 - ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2) - : softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); - // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); } - - // Convert acc_s from fp32 to fp16/bf16 - Tensor rP = pytorch_flash::convert_type(acc_s); - // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) - // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. 
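
When `params.block_table` is non-null, the loop above never keeps one contiguous K/V pointer per sequence: a logical key row is translated into a physical page plus an in-page offset, and only the difference between consecutive (page, offset) pairs is applied to the tKgK/tVgV pointers. A compact C++ sketch of that translation, with hypothetical names and the per-head offset omitted for brevity:

```
#include <cstdint>

// Translate a logical key row of one sequence into a physical element offset
// when K/V live in a paged cache. `block_table` lists the physical page index
// for each logical page of this sequence; `page_block_size` corresponds to
// params.page_block_size and the strides to the K (or V) strides above.
int64_t paged_kv_offset(const int* block_table,
                        int logical_row,
                        int page_block_size,
                        int64_t batch_stride,   // elements between physical pages
                        int64_t row_stride) {   // elements between rows within a page
    const int page = logical_row / page_block_size;               // block_table_idx above
    const int row_in_page = logical_row - page * page_block_size; // block_table_offset above
    return int64_t(block_table[page]) * batch_stride
         + int64_t(row_in_page) * row_stride;
}
// Advancing from key block n to n - 1 therefore only needs the difference
// between the two (page, offset) pairs, which is exactly what the pointer
// updates above and below compute.
```
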
- Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout())); - - pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); - - // This check is at the end of the loop since we always have at least 1 iteration - if (n_masking_steps > 1 && n_block <= n_block_min) { - --n_block; - break; - } - } - - // These are the iterations where we don't need masking on S - for (; n_block >= n_block_min; --n_block) { - Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) - clear(acc_s); - pytorch_flash::cp_async_wait<0>(); - __syncthreads(); - // Advance gV - if (block_table == nullptr) { - tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - } else { - const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size; - const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size; - const int block_table_idx_next = n_block * kBlockN / params.page_block_size; - const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; - tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; - } - pytorch_flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); - cute::cp_async_fence(); - - pytorch_flash::gemm( - acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, - smem_thr_copy_Q, smem_thr_copy_K - ); - - pytorch_flash::cp_async_wait<0>(); - __syncthreads(); - if (n_block > n_block_min) { - // Advance gK - if (block_table == nullptr) { - tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - } else { - const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; - const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; - const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; - const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; - tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; - } - pytorch_flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); - // This cp_async_fence needs to be in the if block, otherwise the synchronization - // isn't right and we get race conditions. - cute::cp_async_fence(); - } - - mask.template apply_mask( - acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 - ); - softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); - - Tensor rP = pytorch_flash::convert_type(acc_s); - // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) - // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. 
- Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout())); - - pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); - } - - // Epilogue - - Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax); - // if (cute::thread0()) { print(lse); } - - Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) - // Partition sO to match the accumulator partitioning - using SmemTiledCopyO = std::conditional_t< - !Split, - typename Kernel_traits::SmemCopyAtomO, - typename Kernel_traits::SmemCopyAtomOaccum - >; - auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma); - auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); - Tensor rO = pytorch_flash::convert_type(acc_o); - Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) - Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) - - // sOaccum is larger than sQ, so we need to syncthreads here - // TODO: allocate enough smem for sOaccum - if constexpr (Split) { __syncthreads(); } - - cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); - - const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) - + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; - const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q - + m_block * kBlockM) * params.d_rounded; - const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; - - Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), - Shape, Int>{}, - make_stride(Split ? kHeadDim : params.o_row_stride, _1{})); - Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum), - Shape>{}, Stride<_1>{}); - // if (tidx == 0) { printf("row_offset_o = %d, bidh = %d, gOaccum = %p\n", row_offset_o, bidh, gOaccum.data()); } - - GmemTiledCopyO gmem_tiled_copy_Oaccum; - auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); - Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); - - __syncthreads(); - - Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); - cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); - - Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) - static_assert(decltype(size<0>(taccOcO))::value == 4); - // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. 
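The `row_offset_oaccum` / `row_offset_lseaccum` expressions above are plain row-major flattening, assuming the accumulator layouts implied by those offsets: `[num_splits, b, h, seqlen_q, d_rounded]` for Oaccum and `[num_splits, b, h, seqlen_q]` for LSEaccum. A small sketch of the same arithmetic; the helper names are mine.

```cpp
#include <cstdint>

// Row-major flattening used by the split-KV accumulator buffers.
std::int64_t oaccum_row_offset(int split, int batch, int head, int row,
                               std::int64_t b, std::int64_t h,
                               std::int64_t seqlen_q, std::int64_t d_rounded) {
  return (((split * b + batch) * h + head) * seqlen_q + row) * d_rounded;
}

std::int64_t lseaccum_row_offset(int split, int batch, int head, int row,
                                 std::int64_t b, std::int64_t h,
                                 std::int64_t seqlen_q) {
  return ((split * b + batch) * h + head) * seqlen_q + row;
}
```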
- Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); - CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M - if (get<1>(taccOcO_row(0)) == 0) { - #pragma unroll - for (int mi = 0; mi < size(lse); ++mi) { - const int row = get<0>(taccOcO_row(mi)); - if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); } - } - } - - // Construct identity layout for sO - Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - // Repeat the partitioning with identity layouts - Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); - if (!Is_even_K) { - #pragma unroll - for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } - } - // Clear_OOB_K must be false since we don't want to write zeros to gmem - pytorch_flash::copy( - gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM - ); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void compute_attn(const Params ¶ms) { - const int m_block = blockIdx.x; - // The block index for the batch. - const int bidb = blockIdx.y; - // The block index for the head. - const int bidh = blockIdx.z; - - // We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting - // them to have the same number of threads or have to traverse the attention matrix - // in the same order. - // In the Philox RNG, we use the offset to store the batch, head, and the lane id - // (within a warp). We use the subsequence to store the location of the 16 x 32 blocks within - // the attention matrix. This way, as long as we have the batch, head, and the location of - // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. - - pytorch_flash::compute_attn_1rowblock(params, bidb, bidh, m_block); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void compute_attn_splitkv(const Params ¶ms) { - const int m_block = blockIdx.x; - // The block index for the batch. - const int bidb = Split ? blockIdx.z / params.h : blockIdx.y; - // The block index for the head. - const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z; - const int n_split_idx = Split ? blockIdx.y : 0; - const int num_n_splits = Split ? 
gridDim.y : 1; - pytorch_flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -constexpr T ceil_div(T numerator, T denominator) { - return (numerator + denominator - 1) / denominator; -} - - -template -inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { - using Element = typename Kernel_traits::Element; - using ElementAccum = typename Kernel_traits::ElementAccum; - using index_t = typename Kernel_traits::index_t; - constexpr int kMaxSplits = 1 << Log_max_splits; - constexpr int kHeadDim = Kernel_traits::kHeadDim; - constexpr int kNThreads = Kernel_traits::kNThreads; - - static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128"); - static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32, "kBlockM must be 4, 8, 16 or 32"); - static_assert(kNThreads == 128, "We assume that each block has 128 threads"); - - // Shared memory. - // kBlockM + 1 instead of kBlockM to reduce bank conflicts. - __shared__ ElementAccum sLSE[kMaxSplits][kBlockM + 1]; - - // The thread and block index. - const int tidx = threadIdx.x; - const int bidx = blockIdx.x; - - const index_t row_offset_lse = bidx * kBlockM; - Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lse), - Shape, Int>{}, - make_stride(params.b * params.h * params.seqlen_q, _1{})); - Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), - Shape>{}, Stride<_1>{}); - constexpr int kNLsePerThread = ceil_div(kMaxSplits * kBlockM, kNThreads); - // Read the LSE values from gmem and store them in shared memory, then tranpose them. - constexpr int kRowsPerLoadLSE = kNThreads / kBlockM; - #pragma unroll - for (int l = 0; l < kNLsePerThread; ++l) { - const int row = l * kRowsPerLoadLSE + tidx / kBlockM; - const int col = tidx % kBlockM; - ElementAccum lse = (row < params.num_splits && col < params.b * params.h * params.seqlen_q - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY; - if (row < kMaxSplits) { sLSE[row][col] = lse; } - // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse); } - } - // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse = %f\n", tidx, row_offset_lse, lse_accum(0)); } - __syncthreads(); - Tensor lse_accum = make_tensor(Shape>{}); - constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits); - // To make sure that kMaxSplits is within 1 warp: we decide how many elements within kMaxSplits - // each thread should hold. If kMaxSplits = 16, then each thread holds 2 elements (128 threads, - // kBlockM rows, so each time we load we can load 128 / kBlockM rows). - // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose; - // static_assert(kThreadsPerSplit <= 32); - static_assert(kRowsPerLoadTranspose <= 32); - static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits); - #pragma unroll - for (int l = 0; l < kNLsePerThread; ++l) { - const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; - const int col = tidx / kRowsPerLoadTranspose; - lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE[row][col] : -INFINITY; - // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); } - } - - // Compute the logsumexp of the LSE along the split dimension. 
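For orientation, the reduction carried out in the block that follows is a max-shifted logsumexp over the per-split LSE values, and further down each split's partial output is rescaled by `exp(lse_i - lse_total)` before being summed. The host-side reference below mirrors that combine, including the handling of all-`-inf` splits; it is illustrative, not the kernel's block-parallel code.

```cpp
#include <cmath>
#include <vector>

// Reference combine for split-KV attention: given per-split partial outputs
// o_split[i] (each already normalized within its split) and per-split LSEs,
// produce the exact full-sequence output and LSE.
void combine_splits_ref(const std::vector<std::vector<float>>& o_split,  // [num_splits][d]
                        const std::vector<float>& lse_split,             // [num_splits]
                        std::vector<float>& o_out, float& lse_out) {
  float lse_max = -INFINITY;
  for (float l : lse_split) lse_max = std::fmax(lse_max, l);
  if (lse_max == -INFINITY) lse_max = 0.f;  // all local LSEs were -inf
  float sum = 0.f;
  for (float l : lse_split) sum += std::exp(l - lse_max);
  // If sum is 0 or NaN, set lse to +inf so the scales below become 0, not NaN.
  lse_out = (sum == 0.f || sum != sum) ? INFINITY : std::log(sum) + lse_max;
  const size_t d = o_split.empty() ? 0 : o_split[0].size();
  o_out.assign(d, 0.f);
  for (size_t i = 0; i < o_split.size(); ++i) {
    const float scale = std::exp(lse_split[i] - lse_out);
    for (size_t k = 0; k < d; ++k) o_out[k] += scale * o_split[i][k];
  }
}
```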
- ElementAccum lse_max = lse_accum(0); - #pragma unroll - for (int l = 1; l < kNLsePerThread; ++l) { lse_max = max(lse_max, lse_accum(l)); } - MaxOp max_op; - lse_max = Allreduce::run(lse_max, max_op); - lse_max = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf - float lse_sum = expf(lse_accum(0) - lse_max); - #pragma unroll - for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); } - SumOp sum_op; - lse_sum = Allreduce::run(lse_sum, sum_op); - // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise - // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum. - ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max; - // Calculate valid rows for this block - const int total_rows = params.b * params.h * params.seqlen_q; - const int local_row = tidx / kRowsPerLoadTranspose; - const int global_row = blockIdx.x * kBlockM + local_row; - - const bool is_reduction_writer = tidx % kRowsPerLoadTranspose == 0; - const bool is_valid_row = (local_row < kBlockM) && (global_row < total_rows); - - if (is_reduction_writer && is_valid_row) { - gLSE(local_row) = lse_logsum; - } - // Store the scales exp(lse - lse_logsum) in shared memory. - #pragma unroll - for (int l = 0; l < kNLsePerThread; ++l) { - const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; - const int col = tidx / kRowsPerLoadTranspose; - if (row < params.num_splits && col < kBlockM) { sLSE[row][col] = expf(lse_accum(l) - lse_logsum); } - } - __syncthreads(); - - const index_t row_offset_oaccum = bidx * kBlockM * params.d_rounded; - Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), - Shape, Int>{}, - Stride, _1>{}); - constexpr int kBlockN = kNThreads / kBlockM; - using GmemLayoutAtomOaccum = Layout, Int>, Stride, _1>>; - using GmemTiledCopyOaccum = decltype( - make_tiled_copy(Copy_Atom, ElementAccum>{}, - GmemLayoutAtomOaccum{}, - Layout>{})); // Val layout, 4 vals per store - GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; - auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); - Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum); - Tensor tOrO = make_tensor(shape(tOgOaccum)); - Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); - clear(tOrO); - - // Predicates - Tensor cOaccum = make_identity_tensor(Shape, Int>{}); - // Repeat the partitioning with identity layouts - Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum); - Tensor tOpOaccum = make_tensor(make_shape(size<2>(tOgOaccum))); - if (!Is_even_K) { - #pragma unroll - for (int k = 0; k < size(tOpOaccum); ++k) { tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; } - } - // Load Oaccum in then scale and accumulate to O - for (int split = 0; split < params.num_splits; ++split) { - pytorch_flash::copy( - gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM - ); - #pragma unroll - for (int m = 0; m < size<1>(tOrOaccum); ++m) { - int row = get<0>(tOcOaccum(0, m, 0)); - ElementAccum lse_scale = sLSE[split][row]; - #pragma unroll - for (int k = 0; k < size<2>(tOrOaccum); ++k) { - #pragma unroll - for (int i = 0; i < size<0>(tOrOaccum); ++i) { - tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k); - } - } - // if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE[split][0], sLSE[split][1]); print(tOrOaccum); } - } - tOgOaccum.data() = 
tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded; - } - // if (cute::thread0()) { print_tensor(tOrO); } - - Tensor rO = pytorch_flash::convert_type(tOrO); - // Write to gO - #pragma unroll - for (int m = 0; m < size<1>(rO); ++m) { - const int idx = bidx * kBlockM + get<0>(tOcOaccum(0, m, 0)); - if (idx < params.b * params.h * params.seqlen_q) { - const int batch_idx = idx / (params.h * params.seqlen_q); - const int head_idx = (idx - batch_idx * (params.h * params.seqlen_q)) / params.seqlen_q; - // The index to the rows of Q - const int row = idx - batch_idx * (params.h * params.seqlen_q) - head_idx * params.seqlen_q; - auto o_ptr = reinterpret_cast(params.o_ptr) + batch_idx * params.o_batch_stride - + head_idx * params.o_head_stride + row * params.o_row_stride; - #pragma unroll - for (int k = 0; k < size<2>(rO); ++k) { - if (Is_even_K || tOpOaccum(k)) { - const int col = get<1>(tOcOaccum(0, m, k)); - Tensor gO = make_tensor(make_gmem_ptr(o_ptr + col), - Shape(rO))::value>>{}, Stride<_1>{}); - // TODO: Should check if this is using vectorized store, but it seems pretty fast - copy(rO(_, m, k), gO); - // if (bidx == 0 && tidx == 0) { printf("tidx = %d, idx = %d, batch_idx = %d, head_idx = %d, row = %d, col = %d\n", tidx, idx, batch_idx, head_idx, row, col); print(rO(_, m, k)); print(gO); } - // reinterpret_cast(o_ptr)[col / 4] = recast(rO)(0, m, k); - } - } - } - } -} - -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h deleted file mode 100644 index 93e183542f6..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h +++ /dev/null @@ -1,378 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include - -#include -#include -#include - -namespace pytorch_flash { - -// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -#define ARCH_SUPPORTS_FLASH -#endif - -#if defined(ARCH_SUPPORTS_FLASH) && defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 11 && \ - defined(__CUDACC_VER_MINOR__) && __CUDACC_VER_MINOR__ >= 8 -#define KERNEL_PARAM_MODIFIER __grid_constant__ -#else -#define KERNEL_PARAM_MODIFIER -#endif - -// Define a macro for unsupported architecture handling to centralize the error message -#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!"); - -// Use a macro to clean up kernel definitions -#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) 
\ -template \ -__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) - -DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax) { - #if defined(ARCH_SUPPORTS_FLASH) - static_assert(!(Is_causal && Is_local)); // Enforce constraints - pytorch_flash::compute_attn(params); - #else - FLASH_UNSUPPORTED_ARCH - #endif -} - -DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV) { - #if defined(ARCH_SUPPORTS_FLASH) - pytorch_flash::compute_attn_splitkv(params); - #else - FLASH_UNSUPPORTED_ARCH - #endif -} - -DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K) { - static_assert(Log_max_splits >= 1); - pytorch_flash::combine_attn_seqk_parallel(params); -} - -template -void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr size_t smem_size = Kernel_traits::kSmemSize; - // printf("smem_size = %d\n", smem_size); - - // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. - // https://github.com/kokkos/kokkos-kernels/issues/349 - // https://github.com/HazyResearch/flash-attention/issues/21 - - const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; - dim3 grid(num_m_block, params.b, params.h); - const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; - const bool is_even_K = params.d == Kernel_traits::kHeadDim; - const bool return_softmax = params.p_ptr != nullptr; - BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { - EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { - LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] { - BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { - ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { - // Will only return softmax if dropout, to reduce compilation time. - // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. 
- // If return_softmax, set IsEvenMNConst to false to reduce number of templates - // If head dim > 128, set IsEvenMNConst to false to reduce number of templates - // If Is_local, set Is_causal to false - auto kernel = &flash_fwd_kernel; - // auto kernel = &flash_fwd_kernel; - // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); - // auto kernel = &flash_fwd_kernel; - if (smem_size >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - // int ctas_per_sm; - // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); - // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); - }); - }); - }); - }); -} - -template -void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { - static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs"); - static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem"); - constexpr size_t smem_size = Kernel_traits::kSmemSize; - const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; - dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h); - const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; - const bool is_even_K = params.d == Kernel_traits::kHeadDim; - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { - EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { - LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] { - BOOL_SWITCH(params.num_splits > 1, Split, [&] { - BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { - ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { - // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. - // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. - // If Is_local, set Is_causal to false - auto kernel = &flash_fwd_splitkv_kernel; - // auto kernel = &flash_fwd_splitkv_kernel; - // auto kernel = &flash_fwd_splitkv_kernel; - if (smem_size >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); - }); - }); - }); - }); - }); - }); - if (params.num_splits > 1) { - // We want kBlockM to be as small as possible for more parallelism. - // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4. - // If headdim is divisible by 64, then we set kBlockM = 8, etc. - constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 
8 : 16); - dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM); - EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { - if (params.num_splits <= 2) { - flash_fwd_splitkv_combine_kernel<<>>(params); - } else if (params.num_splits <= 4) { - flash_fwd_splitkv_combine_kernel<<>>(params); - } else if (params.num_splits <= 8) { - flash_fwd_splitkv_combine_kernel<<>>(params); - } else if (params.num_splits <= 16) { - flash_fwd_splitkv_combine_kernel<<>>(params); - } else if (params.num_splits <= 32) { - flash_fwd_splitkv_combine_kernel<<>>(params); - } else if (params.num_splits <= 64) { - flash_fwd_splitkv_combine_kernel<<>>(params); - } else if (params.num_splits <= 128) { - flash_fwd_splitkv_combine_kernel<<>>(params); - } - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); - } -} - -template -void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr static int kBlockM = 64; // Fixed for all head dimensions - // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256, - // and for headdim 192 with block size 64 x 128. - // Also for headdim 160 with block size 64 x 128 after the rotary addition. - constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64); - run_flash_splitkv_fwd>(params, stream); -} - -template -void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr static int Headdim = 32; - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - }); - }); -} - -template -void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr static int Headdim = 64; - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - if constexpr(!Is_dropout) { - // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower - // Using block size (64 x 256) is 27% slower for seqlen=2k - // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - }); - }); -} - -template -void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr static int Headdim = 96; - auto dprops = at::cuda::getCurrentDeviceProperties(); - bool is_sm8x = dprops->major == 8 && dprops->minor > 0; - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), - if (is_sm8x) { - if constexpr(!Is_causal) { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // These two are always slower - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - }); - }); -} - -template -void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr static int Headdim = 128; - 
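The dispatch code above maps runtime values onto compile-time template parameters: the `*_SWITCH` macros turn runtime booleans into `constexpr` flags, and the `num_splits` if-ladder picks the `Log_max_splits` argument of the combine kernel (effectively the next power of two). A compact, illustrative equivalent of that num_splits mapping; the real launcher keeps the explicit ladder so each branch instantiates a distinct kernel.

```cpp
#include <cstdio>

// Smallest Log_max_splits such that (1 << Log_max_splits) >= num_splits,
// with a floor of 1 because the combine kernel static_asserts Log_max_splits >= 1.
int log_max_splits_for(int num_splits) {
  int log2_splits = 1;
  while ((1 << log2_splits) < num_splits) ++log2_splits;
  return log2_splits;  // e.g. 2 -> 1, 5 -> 3, 16 -> 4, 100 -> 7
}

int main() {
  for (int s : {2, 5, 16, 100}) std::printf("num_splits=%d -> %d\n", s, log_max_splits_for(s));
}
```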
auto dprops = at::cuda::getCurrentDeviceProperties(); - bool is_sm8x = dprops->major == 8 && dprops->minor > 0; - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - if constexpr(!Is_dropout) { - // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), - // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM. - if (is_sm8x) { - if constexpr(!Is_causal) { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // 1st ones are good for H100, A100 - // 2nd one is good for A6000 bc we get slightly better occupancy - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - }); - }); -} - -template -void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr static int Headdim = 160; - auto dprops = at::cuda::getCurrentDeviceProperties(); - bool is_sm8x = dprops->major == 8 && dprops->minor > 0; - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - // For A100, H100, 128 x 32 is the fastest. - // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), - // and 128 x 64 with 8 warps is the fastest for non-causal. 
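The tile-size heuristics here branch on `is_sm8x`, i.e. sm86/sm89-class parts, which have less shared memory per SM than sm80/sm90 and therefore favor smaller (often square) tiles for the causal case. A plain CUDA-runtime sketch of the same capability check, outside ATen, purely for illustration:

```cpp
#include <cuda_runtime.h>
#include <cstdio>

// Equivalent of the is_sm8x check: compute capability 8.x with x > 0
// (sm86/sm87/sm89), as opposed to sm80.
bool is_sm8x_device(int device) {
  cudaDeviceProp prop{};
  if (cudaGetDeviceProperties(&prop, device) != cudaSuccess) return false;
  return prop.major == 8 && prop.minor > 0;
}

int main() {
  int device = 0;
  cudaGetDevice(&device);
  std::printf("is_sm8x = %d\n", int(is_sm8x_device(device)));
}
```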
- if (is_sm8x) { - if constexpr(!Is_causal) { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - }); - }); -} - -template -void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr static int Headdim = 192; - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - if constexpr(!Is_dropout) { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - }); - }); -} - -template -void run_mha_fwd_hdim224(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr static int Headdim = 224; - int device; - cudaGetDevice(&device); - int max_smem_per_block; - cudaError status_ = cudaDeviceGetAttribute( - &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); - if (status_ != cudaSuccess) { - C10_CUDA_CHECK(status_); - } - // printf("max_smem_per_block = %d\n", max_smem_per_block); - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32. - // If we have N = 32, there are only 1024 elements to load at once, where each load - // is 8 elements. This means we can only use 128 threads and not 256 threads. - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - }); - }); -} - -template -void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr static int Headdim = 256; - int device; - cudaGetDevice(&device); - int max_smem_per_sm, max_smem_per_block; - cudaError status_ = cudaDeviceGetAttribute( - &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device); - status_ = cudaDeviceGetAttribute( - &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); - if (status_ != cudaSuccess) { - C10_CUDA_CHECK(status_); - } - // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block); - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - // For A100, we want to run with 128 x 64 (128KB smem). - // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM. 
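The shared-memory thresholds used just below, such as `2 * Headdim * (128 + 2 * 64)`, are the fp16/bf16 footprint of a kBlockM x kBlockN tile without Q/K smem sharing: Q is kBlockM x Headdim and K and V are each kBlockN x Headdim, at 2 bytes per element. A tiny helper reproducing the figures quoted in the comments (112 KB for hdim224, 128 KB vs 96 KB for hdim256); assumes that layout, for illustration only.

```cpp
#include <cstdio>

// 2 bytes/element * Headdim * (kBlockM rows of Q + 2 * kBlockN rows of K and V).
constexpr int tile_smem_bytes(int headdim, int kBlockM, int kBlockN) {
  return 2 * headdim * (kBlockM + 2 * kBlockN);
}

int main() {
  std::printf("hdim224 128x64: %d KB\n", tile_smem_bytes(224, 128, 64) / 1024);  // 112
  std::printf("hdim256 128x64: %d KB\n", tile_smem_bytes(256, 128, 64) / 1024);  // 128
  std::printf("hdim256  64x64: %d KB\n", tile_smem_bytes(256,  64, 64) / 1024);  //  96
}
```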
- if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - // 64 KB - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // 96 KB - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - }); - }); -} - -}; // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernel_traits.h b/aten/src/ATen/native/transformers/cuda/flash_attn/kernel_traits.h deleted file mode 100644 index f4ff1270fad..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernel_traits.h +++ /dev/null @@ -1,347 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include - -#include -#include -#include - -namespace pytorch_flash{ - -using namespace cute; - -template -struct Flash_kernel_traits { - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - using Element = elem_type; - static constexpr bool Has_cp_async = true; -#else - using Element = cutlass::half_t; - static constexpr bool Has_cp_async = false; -#endif - - using ElementAccum = float; - using index_t = int64_t; - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - using MMA_Atom_Arch = std::conditional_t< - std::is_same_v, - MMA_Atom, - MMA_Atom - >; -#else - using MMA_Atom_Arch = MMA_Atom; -#endif - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 - using SmemCopyAtom = Copy_Atom; - using SmemCopyAtomTransposed = Copy_Atom; -#else - using SmemCopyAtom = Copy_Atom; - using SmemCopyAtomTransposed = Copy_Atom; -#endif -}; - -// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true -template > -struct Flash_fwd_kernel_traits : public Base { - using Element = typename Base::Element; - using ElementAccum = typename Base::ElementAccum; - using index_t = typename Base::index_t; - static constexpr bool Has_cp_async = Base::Has_cp_async; - using SmemCopyAtom = typename Base::SmemCopyAtom; - using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; - - static constexpr bool Share_Q_K_smem = Share_Q_K_smem_; - static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem; - - // The number of threads. - static constexpr int kNWarps = kNWarps_; - static constexpr int kNThreads = kNWarps * 32; - - static constexpr int kBlockM = kBlockM_; - static constexpr int kBlockN = kBlockN_; - static constexpr int kHeadDim = kHeadDim_; - static_assert(kHeadDim % 32 == 0); - static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; - static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); - static constexpr int kSwizzle = kBlockKSmem == 32 ? 
2 : 3; - - using TiledMma = TiledMMA< - typename Base::MMA_Atom_Arch, - Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group - Tile, _16, _16>>; - - using SmemLayoutAtomQ = decltype( - composition(Swizzle{}, - // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 - Layout>, - Stride, _1>>{})); - using SmemLayoutQ = decltype(tile_to_shape( - SmemLayoutAtomQ{}, - Shape, Int>{})); - - using SmemLayoutKV = decltype(tile_to_shape( - SmemLayoutAtomQ{}, - Shape, Int>{})); - - // https://github.com/ColfaxResearch/cutlass-kernels/blob/a222587e6d59b93ba704853d3946fb686d8b8892/src/fmha/fmha_forward.cu#L434 - using SmemLayoutVtransposed = decltype( - composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); - using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{})); - - using SmemLayoutAtomO = decltype( - composition(Swizzle{}, - Layout, Int>, - Stride, _1>>{})); - using SmemLayoutO = decltype(tile_to_shape( - SmemLayoutAtomO{}, - Shape, Int>{})); - using SmemCopyAtomO = Copy_Atom, Element>; - using SmemCopyAtomOaccum = Copy_Atom, ElementAccum>; - - static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element); - static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); - static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize; - - static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); - // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts. - // For example, for d=128, smem is split into 2 "pages", each page takes care of columns - // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem, - // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page, - // to the same banks. - static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; - static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); - using GmemLayoutAtom = Layout, Int>, - Stride, _1>>; - - // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading - // from the same address by the same threadblock. This is slightly faster. 
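The gmem copy tiling in these traits is driven by two small constants: each thread issues 128-bit loads, so it moves `sizeof(uint128_t) / sizeof(Element)` elements at a time, and `kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad` threads cooperate on one smem row. A sketch of that arithmetic, assuming 2-byte (fp16/bf16) elements as in the traits above:

```cpp
#include <cstdint>
#include <cstdio>

// Elements moved per 128-bit (16-byte) vectorized load.
template <typename Element>
constexpr int gmem_elems_per_load() { return int(16 / sizeof(Element)); }

constexpr int gmem_threads_per_row(int kBlockKSmem, int elems_per_load) {
  return kBlockKSmem / elems_per_load;
}

int main() {
  // fp16/bf16: 8 elements per load; kBlockKSmem = 64 -> 8 threads per row,
  // kBlockKSmem = 32 -> 4 threads per row (the two cases the traits switch on).
  std::printf("%d %d\n",
              gmem_threads_per_row(64, gmem_elems_per_load<std::uint16_t>()),
              gmem_threads_per_row(32, gmem_elems_per_load<std::uint16_t>()));
}
```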
- using Gmem_copy_struct = std::conditional_t< - Has_cp_async, - SM80_CP_ASYNC_CACHEGLOBAL, - AutoVectorizingCopyWithAssumedAlignment<128> - >; - using GmemTiledCopyQKV = decltype( - make_tiled_copy(Copy_Atom{}, - GmemLayoutAtom{}, - Layout>{})); // Val layout, 8 vals per read - using GmemTiledCopyO = decltype( - make_tiled_copy(Copy_Atom, Element>{}, - GmemLayoutAtom{}, - Layout>{})); // Val layout, 8 vals per store - - using GmemLayoutAtomOaccum = std::conditional_t< - kBlockKSmem == 32, - Layout, // Thread layout, 8 threads per row - Stride< _8, _1>>, - Layout, // Thread layout, 16 threads per row - Stride< _16, _1>> - >; - using GmemTiledCopyOaccum = decltype( - make_tiled_copy(Copy_Atom, ElementAccum>{}, - GmemLayoutAtomOaccum{}, - Layout>{})); // Val layout, 4 vals per store - using GmemLayoutAtomRotcossin = GmemLayoutAtom; - using GmemTiledCopyRotcossin = decltype( - make_tiled_copy(Copy_Atom, Element>{}, - GmemLayoutAtomRotcossin{}, - Layout>{})); // Val layout, 4 vals per load - using GmemTiledCopyRotcossinCont = decltype( - make_tiled_copy(Copy_Atom, Element>{}, - GmemLayoutAtomRotcossin{}, - Layout>{})); // Val layout, 8 vals per load -}; - -// Is_V_in_regs is an option to reduce smem usage, but will increase register pressue. -// No_double_buffer is another option to reduce smem usage, but will slow things down. -template > -struct Flash_bwd_kernel_traits : public Base { - using Element = typename Base::Element; - using ElementAccum = typename Base::ElementAccum; - using index_t = typename Base::index_t; - static constexpr bool Has_cp_async = Base::Has_cp_async; - using SmemCopyAtom = typename Base::SmemCopyAtom; - using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; - - static constexpr bool Is_V_in_regs = Is_V_in_regs_; - static constexpr bool No_double_buffer = No_double_buffer_; - - // The number of threads. - static constexpr int kNWarps = kNWarps_; - static constexpr int kNThreads = kNWarps * 32; - - static constexpr int kBlockM = kBlockM_; - static constexpr int kBlockN = kBlockN_; - static constexpr int kHeadDim = kHeadDim_; - static_assert(kHeadDim % 32 == 0); - static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; - static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); - static constexpr int kSwizzle = kBlockKSmem == 32 ? 
2 : 3; - - static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_; - static_assert(kNWarps % AtomLayoutMSdP == 0); - static_assert(kNWarps % AtomLayoutNdKV == 0); - static_assert(kNWarps % AtomLayoutMdQ == 0); - - using TiledMmaSdP = TiledMMA< - typename Base::MMA_Atom_Arch, - Layout, Int, _1>>, - Tile, Int<16 * kNWarps / AtomLayoutMSdP>, _16>>; - - using TiledMmadKV = TiledMMA< - typename Base::MMA_Atom_Arch, - Layout, Int, _1>>, - Tile, Int<16 * kNWarps / AtomLayoutNdKV>, _16>>; - - using TiledMmadQ = TiledMMA< - typename Base::MMA_Atom_Arch, - Layout, Int, _1>>, // 2x4x1 or 4x2x1 thread group - Tile, Int<16 * kNWarps / AtomLayoutMdQ>, _16>>; - - using SmemLayoutAtomQdO = decltype( - composition(Swizzle{}, - Layout>, - Stride, _1>>{})); - using SmemLayoutQdO = decltype(tile_to_shape( - SmemLayoutAtomQdO{}, - make_shape(Int{}, Int{}))); - - using SmemLayoutAtomKV = decltype( - composition(Swizzle{}, - Layout, Int>, - Stride, _1>>{})); - using SmemLayoutKV = decltype(tile_to_shape( - // SmemLayoutAtomQdO{}, - SmemLayoutAtomKV{}, - make_shape(Int{}, Int{}))); - - using SmemLayoutKtransposed = decltype( - composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); - using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{})); - - // TODO: generalize to other values of kBlockN - // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2 - // static constexpr int kPBlockN = kBlockN; - // Temporarily disabling this for hdim 256 on sm86 and sm89 - // static_assert(kBlockN >= 64); - static_assert(kBlockN >= 32); - // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest. - static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32; - static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64); - // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3); - static constexpr int kSwizzlePdS = 3; - using SmemLayoutAtomPdS = decltype( - composition(Swizzle{}, - Layout, Int>, - Stride, _1>>{})); - using SmemLayoutPdS = decltype(tile_to_shape( - SmemLayoutAtomPdS{}, - make_shape(Int{}, Int{}))); - using SmemLayoutPdStransposed = decltype( - composition(SmemLayoutPdS{}, make_layout(Shape, Int>{}, GenRowMajor{}))); - using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{})); - - using SmemCopyAtomPdS = Copy_Atom, elem_type>; - - using SmemLayoutQdOtransposed = decltype( - composition(SmemLayoutQdO{}, make_layout(Shape, Int>{}, GenRowMajor{}))); - using SmemLayoutQdOtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutQdOtransposed{})); - - using SmemLayoutAtomdKV = decltype( - composition(Swizzle{}, - Layout>, - Stride, _1>>{})); - using SmemLayoutdKV = decltype(tile_to_shape( - SmemLayoutAtomdKV{}, - make_shape(Int{}, Int{}))); - using SmemCopyAtomdKV = Copy_Atom, elem_type>; - - using SmemLayoutAtomdQ = decltype( - composition(Swizzle{}, - Layout>, - Stride, _1>>{})); - using SmemLayoutdQ = decltype(tile_to_shape( - SmemLayoutAtomdQ{}, - make_shape(Int{}, Int{}))); - using SmemCopyAtomdQ = Copy_Atom, elem_type>; - - // Double buffer for sQ - static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 
2 : 3) * sizeof(Element); - static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); - static constexpr int kSmemdSSize = size(SmemLayoutPdS{}) * sizeof(Element); - static constexpr int kSmemPSize = size(SmemLayoutPdS{}) * sizeof(Element); - static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element); - static constexpr int kSmemSize = kSmemQdOSize - + (!Is_V_in_regs - ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) - : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize))); - static constexpr int kSmemSize1colblock = kSmemQdOSize - + (!Is_V_in_regs - ? kSmemKVSize + kSmemdSSize + kSmemPSize - : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize)); - - static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); - // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem - // to affect speed in practice. - static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; - static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); - using GmemLayoutAtom = Layout, Int>, - Stride, _1>>; - - // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading - // from the same address by the same threadblock. This is slightly faster. - using Gmem_copy_struct = std::conditional_t< - Has_cp_async, - SM80_CP_ASYNC_CACHEGLOBAL, - AutoVectorizingCopyWithAssumedAlignment<128> - >; - using GmemTiledCopyQKV = decltype( - make_tiled_copy(Copy_Atom{}, - GmemLayoutAtom{}, - Layout>{})); // Val layout, 8 vals per read - using GmemTiledCopydO = decltype( - make_tiled_copy(Copy_Atom, elem_type>{}, - GmemLayoutAtom{}, - Layout>{})); // Val layout, 8 vals per store - using GmemTiledCopydKV = decltype( - make_tiled_copy(Copy_Atom, elem_type>{}, - GmemLayoutAtom{}, - Layout>{})); // Val layout, 8 vals per store - using GmemTiledCopydQ = decltype( - make_tiled_copy(Copy_Atom, elem_type>{}, - GmemLayoutAtom{}, - Layout>{})); // Val layout, 8 vals per store - using GmemLayoutAtomdQaccum = std::conditional_t< - kBlockKSmem == 32, - Layout, // Thread layout, 8 threads per row - Stride< _8, _1>>, - Layout, // Thread layout, 16 threads per row - Stride< _16, _1>> - >; - using GmemTiledCopydQaccum = decltype( - make_tiled_copy(Copy_Atom, ElementAccum>{}, - GmemLayoutAtomdQaccum{}, - Layout>{})); // Val layout, 4 vals per store - - using GmemTiledCopydQaccumAtomicAdd = decltype( - make_tiled_copy(Copy_Atom, ElementAccum>{}, - Layout, // Thread layout, 8 threads per row - Stride<_32, _1>>{}, - Layout>{})); // Val layout, 1 val per store - -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim128_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim128_bf16_sm80.cu deleted file mode 100644 index 63a80c4d206..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim128_bf16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. 
See "generate_kernels.py" - -#include -namespace pytorch_flash{ - -template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim128(params, stream); -} -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim128_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim128_fp16_sm80.cu deleted file mode 100644 index 720f54343a4..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim128_fp16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include -namespace pytorch_flash{ - -template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim128(params, stream); -} -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim160_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim160_bf16_sm80.cu deleted file mode 100644 index 04aa184a6f7..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim160_bf16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include -namespace pytorch_flash{ - -template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim160(params, stream); -} -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim160_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim160_fp16_sm80.cu deleted file mode 100644 index 97908216299..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim160_fp16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include -namespace pytorch_flash{ - -template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim160(params, stream); -} -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim192_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim192_bf16_sm80.cu deleted file mode 100644 index 76ac4426f03..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim192_bf16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. 
See "generate_kernels.py" - -#include -namespace pytorch_flash{ - -template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim192(params, stream); -} -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim192_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim192_fp16_sm80.cu deleted file mode 100644 index d0a05f59721..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim192_fp16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include -namespace pytorch_flash{ - -template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim192(params, stream); -} -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim224_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim224_bf16_sm80.cu deleted file mode 100644 index 14ce1a9a450..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim224_bf16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include -namespace pytorch_flash{ - -template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim224(params, stream); -} -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim224_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim224_fp16_sm80.cu deleted file mode 100644 index 259c84cf8cd..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim224_fp16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include -namespace pytorch_flash{ - -template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim224(params, stream); -} -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim256_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim256_bf16_sm80.cu deleted file mode 100644 index 1767b60f790..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim256_bf16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. 
See "generate_kernels.py" - -#include -namespace pytorch_flash{ - -template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim256(params, stream); -} -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim256_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim256_fp16_sm80.cu deleted file mode 100644 index 6381904f7b5..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim256_fp16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include -namespace pytorch_flash{ - -template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim256(params, stream); -} -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim32_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim32_bf16_sm80.cu deleted file mode 100644 index bd47a37e7f6..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim32_bf16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include -namespace pytorch_flash{ - -template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim32(params, stream); -} -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim32_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim32_fp16_sm80.cu deleted file mode 100644 index ae046260c37..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim32_fp16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include -namespace pytorch_flash{ - -template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim32(params, stream); -} -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim64_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim64_bf16_sm80.cu deleted file mode 100644 index 42314aac9d2..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim64_bf16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. 
See "generate_kernels.py" - -#include -namespace pytorch_flash{ - -template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim64(params, stream); -} -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim64_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim64_fp16_sm80.cu deleted file mode 100644 index 616c784f752..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim64_fp16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include -namespace pytorch_flash{ - -template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim64(params, stream); -} -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim96_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim96_bf16_sm80.cu deleted file mode 100644 index 6eccc4f455a..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim96_bf16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include -namespace pytorch_flash{ - -template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim96(params, stream); -} -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim96_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim96_fp16_sm80.cu deleted file mode 100644 index 54e455b81a3..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim96_fp16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include -namespace pytorch_flash{ - -template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim96(params, stream); -} -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim128_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim128_bf16_sm80.cu deleted file mode 100644 index 99a95b8354b..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim128_bf16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. 
See "generate_kernels.py" - -#include -namespace pytorch_flash{ - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim128(params, stream); -} -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim128_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim128_fp16_sm80.cu deleted file mode 100644 index 06a716e4073..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim128_fp16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include -namespace pytorch_flash{ - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim128(params, stream); -} -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim160_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim160_bf16_sm80.cu deleted file mode 100644 index d3edcc95523..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim160_bf16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include -namespace pytorch_flash{ - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim160(params, stream); -} -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim160_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim160_fp16_sm80.cu deleted file mode 100644 index a5af5a2d3f6..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim160_fp16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include -namespace pytorch_flash{ - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim160(params, stream); -} -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim192_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim192_bf16_sm80.cu deleted file mode 100644 index 2c937b2feba..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim192_bf16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. 
See "generate_kernels.py" - -#include -namespace pytorch_flash{ - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim192(params, stream); -} -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim192_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim192_fp16_sm80.cu deleted file mode 100644 index df519ea6f3b..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim192_fp16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include -namespace pytorch_flash{ - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim192(params, stream); -} -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim224_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim224_bf16_sm80.cu deleted file mode 100644 index 39a0109016d..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim224_bf16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include -namespace pytorch_flash{ - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim224(params, stream); -} -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim224_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim224_fp16_sm80.cu deleted file mode 100644 index 1da19195461..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim224_fp16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include -namespace pytorch_flash{ - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim224(params, stream); -} -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim256_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim256_bf16_sm80.cu deleted file mode 100644 index 30e6e5b9fc8..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim256_bf16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. 
See "generate_kernels.py" - -#include -namespace pytorch_flash{ - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim256(params, stream); -} -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim256_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim256_fp16_sm80.cu deleted file mode 100644 index 55036693f50..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim256_fp16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include -namespace pytorch_flash{ - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim256(params, stream); -} -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim32_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim32_bf16_sm80.cu deleted file mode 100644 index e4900bed280..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim32_bf16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include -namespace pytorch_flash{ - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim32(params, stream); -} -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim32_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim32_fp16_sm80.cu deleted file mode 100644 index 8134c2b4bb6..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim32_fp16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include -namespace pytorch_flash{ - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim32(params, stream); -} -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim64_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim64_bf16_sm80.cu deleted file mode 100644 index ffbb783a6c5..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim64_bf16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. 
See "generate_kernels.py" - -#include -namespace pytorch_flash{ - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim64(params, stream); -} -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim64_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim64_fp16_sm80.cu deleted file mode 100644 index c109d79f8b3..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim64_fp16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include -namespace pytorch_flash{ - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim64(params, stream); -} -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim96_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim96_bf16_sm80.cu deleted file mode 100644 index 87f0e0f5f18..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim96_bf16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include -namespace pytorch_flash{ - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim96(params, stream); -} -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim96_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim96_fp16_sm80.cu deleted file mode 100644 index e21c1f4ecfa..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_hdim96_fp16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include -namespace pytorch_flash{ - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim96(params, stream); -} -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim128_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim128_bf16_sm80.cu deleted file mode 100644 index 80b9969a03c..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim128_bf16_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. 
See "generate_kernels.py" - -#include -namespace pytorch_flash{ - - -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim128_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim128_fp16_sm80.cu deleted file mode 100644 index 226c3b1e9b9..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim128_fp16_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include -namespace pytorch_flash{ - - -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim160_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim160_bf16_sm80.cu deleted file mode 100644 index eed00f89c33..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim160_bf16_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include -namespace pytorch_flash{ - - -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim160_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim160_fp16_sm80.cu deleted file mode 100644 index 4081ae3844a..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim160_fp16_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include -namespace pytorch_flash{ - - -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim192_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim192_bf16_sm80.cu deleted file mode 100644 index f83bb87dd9a..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim192_bf16_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. 
See "generate_kernels.py" - -#include -namespace pytorch_flash{ - - -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim192_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim192_fp16_sm80.cu deleted file mode 100644 index 496b46ff11e..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim192_fp16_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include -namespace pytorch_flash{ - - -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim224_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim224_bf16_sm80.cu deleted file mode 100644 index 89997121518..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim224_bf16_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include -namespace pytorch_flash{ - - -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim224_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim224_fp16_sm80.cu deleted file mode 100644 index 691cd15e97d..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim224_fp16_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include -namespace pytorch_flash{ - - -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim256_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim256_bf16_sm80.cu deleted file mode 100644 index 25d7543677c..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim256_bf16_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. 
See "generate_kernels.py" - -#include -namespace pytorch_flash{ - - -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim256_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim256_fp16_sm80.cu deleted file mode 100644 index 6ebce189407..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim256_fp16_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include -namespace pytorch_flash{ - - -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim32_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim32_bf16_sm80.cu deleted file mode 100644 index 88a67a5110c..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim32_bf16_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include -namespace pytorch_flash{ - - -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim32_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim32_fp16_sm80.cu deleted file mode 100644 index a7f7acc47dc..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim32_fp16_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include -namespace pytorch_flash{ - - -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim64_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim64_bf16_sm80.cu deleted file mode 100644 index 2cb919bf2c7..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim64_bf16_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include -namespace pytorch_flash{ - - -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim64_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim64_fp16_sm80.cu deleted file mode 100644 index b5b1fb5516f..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim64_fp16_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. 
- -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include -namespace pytorch_flash{ - - -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim96_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim96_bf16_sm80.cu deleted file mode 100644 index e3fcb467773..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim96_bf16_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include -namespace pytorch_flash{ - - -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim96_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim96_fp16_sm80.cu deleted file mode 100644 index 6621d0da07c..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_fwd_split_hdim96_fp16_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ - -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" - -#include -namespace pytorch_flash{ - - -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/generate_kernels.py b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/generate_kernels.py deleted file mode 100644 index 803c5390768..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/generate_kernels.py +++ /dev/null @@ -1,109 +0,0 @@ -# This file is run to generate the kernel instantiations for the flash_attn kernels -# They are written to several files in order to speed up compilation - -import argparse -import itertools -from dataclasses import dataclass -from pathlib import Path -from typing import Optional - - -DTYPE_MAP = { - "fp16": "cutlass::half_t", - "bf16": "cutlass::bfloat16_t", -} - -SM = [80] # Sm80 kernels support up to -HEAD_DIMENSIONS = [32, 64, 96, 128, 160, 192, 224, 256] -KERNEL_IMPL_TEMPLATE_FWD = """ -template<> -void run_mha_fwd_<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params ¶ms, cudaStream_t stream) {{ - run_mha_fwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream); -}} -""" -KERNEL_IMPL_TEMPLATE_FWD_SPLIT = """ - -template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params ¶ms, cudaStream_t stream); -""" - -KERNEL_IMPL_TEMPLATE_BWD = """ -template<> -void run_mha_bwd_<{DTYPE}, {HEAD_DIM}>(Flash_bwd_params ¶ms, cudaStream_t stream) {{ - run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream); -}} -""" - - -@dataclass -class Kernel: - sm: int - dtype: str - head_dim: int - direction: str - - @property - def template(self) -> str: - if self.direction == "fwd": - return KERNEL_IMPL_TEMPLATE_FWD.format( - DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim - ) - elif self.direction == "bwd": - return KERNEL_IMPL_TEMPLATE_BWD.format( - DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim - ) - else: - return 
KERNEL_IMPL_TEMPLATE_FWD_SPLIT.format( - DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim - ) - - @property - def filename(self) -> str: - return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_sm{self.sm}.cu" - - -def get_all_kernels() -> list[Kernel]: - for dtype, head_dim, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SM): - for direction in ["fwd", "bwd", "fwd_split"]: - yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, direction=direction) - - -def write_kernel(kernel: Kernel, autogen_dir: Path) -> None: - prelude = """ -// Copyright (c) 2023, Tri Dao. - -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py"\n -""" - launch_template_str = kernel.direction if kernel.direction != "fwd_split" else "fwd" - include = f"#include \n" - namespace = "namespace pytorch_flash{\n" - namespace_end = "} // namespace pytorch_flash\n" - (autogen_dir / kernel.filename).write_text( - prelude + include + namespace + kernel.template + namespace_end - ) - - -def main(output_dir: Optional[str]) -> None: - if output_dir is None: - output_dir = Path(__file__).parent - else: - output_dir = Path(output_dir) - - for kernel in get_all_kernels(): - write_kernel(kernel, output_dir) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - prog="generate_kernels", - description="Generate the flash_attention kernels template instantiations", - ) - # Set an optional output directory - parser.add_argument( - "-o", - "--output_dir", - required=False, - help="Where to generate the kernels will default to the current directory", - ) - args = parser.parse_args() - main(args.output_dir) diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/mask.h b/aten/src/ATen/native/transformers/cuda/flash_attn/mask.h deleted file mode 100644 index 9cee154fbbd..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/mask.h +++ /dev/null @@ -1,213 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Tri Dao. 
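The deleted generate_kernels.py above writes one translation unit per (dtype, head dimension, direction) combination so the heavy template instantiations compile in parallel. As a quick illustration of the file matrix those removed sources correspond to, here is a minimal Python sketch; the constants mirror the deleted script, but the snippet is illustrative only and is not part of the patch.

```
# Minimal sketch of the instantiation matrix covered by the deleted
# generate_kernels.py: 2 dtypes x 8 head dims x 3 directions = 48 files.
# Constants mirror the deleted script; illustrative only.
import itertools

DTYPES = ["fp16", "bf16"]                      # cutlass::half_t / cutlass::bfloat16_t
HEAD_DIMS = [32, 64, 96, 128, 160, 192, 224, 256]
DIRECTIONS = ["fwd", "bwd", "fwd_split"]
SM = 80

filenames = [
    f"flash_{direction}_hdim{hdim}_{dtype}_sm{SM}.cu"
    for dtype, hdim, direction in itertools.product(DTYPES, HEAD_DIMS, DIRECTIONS)
]

print(len(filenames))   # 48 single-instantiation translation units
print(filenames[0])     # flash_fwd_hdim32_fp16_sm80.cu
```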
- ******************************************************************************/ - -#pragma once - -#include - -namespace pytorch_flash { - -using namespace cute; - -template -__forceinline__ __device__ void apply_mask(Tensor &tensor, const int max_seqlen_k, - const int col_idx_offset_ = 0) { - // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) - static_assert(Layout::rank == 2, "Only support 2D Tensor"); - const int lane_id = threadIdx.x % 32; - const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; - #pragma unroll - for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { - const int col_idx_base = col_idx_offset + nj * 8; - #pragma unroll - for (int j = 0; j < size<1, 0>(tensor); ++j) { - const int col_idx = col_idx_base + j; - if (col_idx >= max_seqlen_k) { - // Without the "make_coord" we get wrong results - #pragma unroll - for (int mi = 0; mi < size<0>(tensor); ++mi) { - tensor(mi, make_coord(j, nj)) = -INFINITY; - } - } - } - } -} - -template -__forceinline__ __device__ void apply_mask_local(Tensor &tensor, const int col_idx_offset_, - const int max_seqlen_k, const int row_idx_offset, - const int max_seqlen_q, const int warp_row_stride, - const int window_size_left, const int window_size_right) { - // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) - static_assert(Layout::rank == 2, "Only support 2D Tensor"); - const int lane_id = threadIdx.x % 32; - const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; - #pragma unroll - for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { - const int row_idx_base = row_idx_offset + mi * warp_row_stride; - #pragma unroll - for (int i = 0; i < size<0, 0>(tensor); ++i) { - const int row_idx = row_idx_base + i * 8; - const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); - const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); - #pragma unroll - for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { - const int col_idx_base = col_idx_offset + nj * 8; - #pragma unroll - for (int j = 0; j < size<1, 0>(tensor); ++j) { - const int col_idx = col_idx_base + j; - if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) { - tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; - } - } - } - // if (cute::thread0()) { - // printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k); - // print(tensor(make_coord(i, mi), _)); - // // print(tensor(_, j + nj * size<1, 0>(tensor))); - // } - } - } -} - -template -__forceinline__ __device__ void apply_mask_causal(Tensor &tensor, const int col_idx_offset_, - const int max_seqlen_k, const int row_idx_offset, - const int max_seqlen_q, const int warp_row_stride) { - // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0 - apply_mask_local(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset, - max_seqlen_q, warp_row_stride, -1, 0); -} - -template -__forceinline__ __device__ void apply_mask_causal_w_idx( - Tensor &tensor, Tensor const &idx_rowcol, - const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset) -{ - // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) - static_assert(Layout0::rank == 2, "Only support 2D Tensor"); - static_assert(Layout1::rank == 2, "Only support 2D Tensor"); - CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol)); - CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol)); - #pragma unroll - for (int mi = 0; mi < 
size<0>(tensor); ++mi) { - const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0))); - #pragma unroll - for (int ni = 0; ni < size<1, 1>(tensor); ++ni) { - if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) { - tensor(mi, ni) = -INFINITY; - } - } - // if (cute::thread0()) { - // printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k); - // print(tensor(_, make_coord(j, ni))); - // // print(tensor(_, j + ni * size<1, 0>(tensor))); - // } - } -} - -template -struct Mask { - - const int max_seqlen_k, max_seqlen_q; - const int window_size_left, window_size_right; - const float alibi_slope; - - __forceinline__ __device__ Mask(const int max_seqlen_k, const int max_seqlen_q, - const int window_size_left, const int window_size_right, - const float alibi_slope=0.f) - : max_seqlen_k(max_seqlen_k) - , max_seqlen_q(max_seqlen_q) - , window_size_left(window_size_left) - , window_size_right(window_size_right) - , alibi_slope(!Has_alibi ? 0.0 : alibi_slope) { - }; - - // Causal_mask: whether this particular iteration needs causal masking - template - __forceinline__ __device__ void apply_mask(Tensor &tensor_, - const int col_idx_offset_, - const int row_idx_offset, - const int warp_row_stride) { - static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local"); - static_assert(Layout::rank == 3, "Only support 3D Tensor"); - static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4"); - static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN; - // if (cute::thread0()) { printf("Has_alibi = %d, Causal_mask=%d, Is_local=%d, Is_even_MN = %d, Need_masking = %d\n", Has_alibi, Causal_mask, Is_local, Is_even_MN, Need_masking); } - if constexpr (Need_masking) { - // Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) - Tensor tensor = make_tensor(tensor_.data(), pytorch_flash::convert_layout_acc_rowcol(tensor_.layout())); - // Do we need both row and column indices, or just column incides? 
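The comment above asks whether row indices are needed at all. When neither causal nor local masking applies, they are not: the mask depends only on the key (column) index, and non-causal ALiBi likewise adds a bias that is a function of the column alone. A rough NumPy sketch of that column-only path, on a dense score matrix rather than the kernel's register-fragment layout (function and variable names are illustrative):

```
# Rough sketch of the "column-index-only" masking path: with no causal or
# local masking, the mask depends only on the key/column index. Dense
# (seqlen_q, padded_k) matrix instead of MMA fragments; illustrative only.
import numpy as np

def apply_col_only_mask(scores, max_seqlen_k, alibi_slope=0.0):
    _, padded_k = scores.shape
    col_idx = np.arange(padded_k)
    out = scores.astype(np.float32).copy()
    # Non-causal ALiBi adds a bias that grows with the column index.
    if alibi_slope != 0.0:
        out = out + alibi_slope * col_idx
    # Columns past the real key length are forced to -inf.
    out[:, col_idx >= max_seqlen_k] = -np.inf
    return out

scores = np.zeros((4, 8), dtype=np.float32)
masked = apply_col_only_mask(scores, max_seqlen_k=6, alibi_slope=0.1)
print(masked[0])   # last two columns are -inf, the rest carry the ALiBi bias
```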
- static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask; - const int lane_id = threadIdx.x % 32; - const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; - if constexpr (Col_idx_only) { - #pragma unroll - for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { - const int col_idx_base = col_idx_offset + nj * 8; - #pragma unroll - for (int j = 0; j < size<1, 0>(tensor); ++j) { - const int col_idx = col_idx_base + j; - #pragma unroll - for (int mi = 0; mi < size<0>(tensor); ++mi) { - // No causal, no local - if constexpr (Has_alibi) { - tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx; - } - if constexpr (!Is_even_MN) { - if (col_idx >= max_seqlen_k) { tensor(mi, make_coord(j, nj)) = -INFINITY; } - } - } - } - } - } else { - #pragma unroll - for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { - const int row_idx_base = row_idx_offset + mi * warp_row_stride; - #pragma unroll - for (int i = 0; i < size<0, 0>(tensor); ++i) { - const int row_idx = row_idx_base + i * 8; - const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); - const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); - #pragma unroll - for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { - const int col_idx_base = col_idx_offset + nj * 8; - #pragma unroll - for (int j = 0; j < size<1, 0>(tensor); ++j) { - const int col_idx = col_idx_base + j; - if constexpr (Has_alibi) { - if constexpr (Is_causal) { - tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope * col_idx; - } else { - tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx); - - } - } - if constexpr (Causal_mask) { - if (col_idx >= col_idx_limit_right) { - tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; - } - } - if constexpr (Is_local) { - if (col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left) { - tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; - } - } - if constexpr (!Causal_mask && !Is_local && !Is_even_MN) { - // Causal and Local already handles MN masking - if (col_idx >= max_seqlen_k) { - tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; - } - } - } - } - } - } - } - } - }; - -}; - -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/rotary.h b/aten/src/ATen/native/transformers/cuda/flash_attn/rotary.h deleted file mode 100644 index 12dc1746c80..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/rotary.h +++ /dev/null @@ -1,152 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Tri Dao. 
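For the local (sliding-window) path in the mask.h removed above, each query row keeps only the key columns inside [limit_left, limit_right), with the diagonal shifted by max_seqlen_k - max_seqlen_q, and causal masking is the special case of an unbounded left window with window_size_right = 0. A hedged dense-matrix sketch of that limit arithmetic (illustrative names, not the kernel code):

```
# Sketch of the sliding-window ("local") mask limits: a query row may attend
# to key columns in [limit_left, limit_right). Causal masking is the special
# case of an unbounded left window and window_size_right = 0. Dense-matrix
# illustration only.
import numpy as np

def local_mask(seqlen_q, seqlen_k, window_left, window_right):
    keep = np.zeros((seqlen_q, seqlen_k), dtype=bool)
    shift = seqlen_k - seqlen_q          # aligns the diagonal when seqlen_q != seqlen_k
    for row in range(seqlen_q):
        left = 0 if window_left < 0 else max(0, row + shift - window_left)
        right = min(seqlen_k, row + 1 + shift + window_right)
        keep[row, left:right] = True
    return keep

def causal_mask(seqlen_q, seqlen_k):
    # Same as a local mask with an infinite left window and no right window.
    return local_mask(seqlen_q, seqlen_k, window_left=-1, window_right=0)

print(causal_mask(3, 5).astype(int))
# [[1 1 1 0 0]
#  [1 1 1 1 0]
#  [1 1 1 1 1]]
```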
- ******************************************************************************/ - -#pragma once - -#include - -#include - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace pytorch_flash { - -using namespace cute; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__forceinline__ __device__ void copy_rotary_interleaved(Tensor const &S, - Tensor &D, - Tensor const &Cos, - Tensor const &Sin, - Tensor const &identity_MN, - const int max_MN, const int min_MN, - const int dim, const int rotary_dim) { - CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); - CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); - CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K - CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K - static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2); - static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 - Tensor rCos = make_fragment_like(Cos); - Tensor rSin = make_fragment_like(Sin); - Tensor rS = make_fragment_like(S); - #pragma unroll - for (int m = 0; m < size<1>(S); ++m) { - if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { - #pragma unroll - for (int k = 0; k < size<2>(S); ++k) { - if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { - cute::copy(S(_, m, k), rS(_, m, k)); - if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { - cute::copy(Cos(_, m, k), rCos(_, m, k)); - cute::copy(Sin(_, m, k), rSin(_, m, k)); - Tensor S_fp32 = convert_type(rS(_, m, k)); - Tensor cos_fp32 = convert_type(rCos(_, m, k)); - Tensor sin_fp32 = convert_type(rSin(_, m, k)); - #pragma unroll - for (int i = 0; i < size<0>(rS) / 2; ++i) { - float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i); - float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i); - S_fp32(2 * i) = real; - S_fp32(2 * i + 1) = imag; - } - // Idk but I need to copy for the convert_type to work - Tensor S_fp32_copy = make_fragment_like(S_fp32); - cute::copy(S_fp32, S_fp32_copy); - using T = typename Engine0::value_type; - Tensor S_og_type = convert_type(S_fp32_copy); - cute::copy(S_og_type, rS(_, m, k)); - } - cute::copy(rS(_, m, k), D(_, m, k)); - } else if (Clear_OOB_K) { - cute::clear(D(_, m, k)); - } - } - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__forceinline__ __device__ void copy_rotary_contiguous(Tensor const &S, - Tensor &D, - Tensor const &Cos, - Tensor const &Sin, - Tensor const &identity_MN, - const int max_MN, const int min_MN, - const int dim, const int rotary_dim) { - CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); - CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); - CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); 
// MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K - CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA - CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); - static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 - Tensor rCos = make_fragment_like(Cos); - Tensor rSin = make_fragment_like(Sin); - Tensor rS = make_fragment_like(S); - Tensor rS_other = make_fragment_like(rS(_, 0, 0)); - #pragma unroll - for (int m = 0; m < size<1>(S); ++m) { - if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { - #pragma unroll - for (int k = 0; k < size<2>(S); ++k) { - if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { - cute::copy(S(_, m, k), rS(_, m, k)); - if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { - const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2; - Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout()); - cute::copy(gS_other, rS_other); - // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); } - Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout()); - Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout()); - cute::copy(gCos, rCos(_, m, k)); - cute::copy(gSin, rSin(_, m, k)); - // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); } - Tensor S_fp32 = convert_type(rS(_, m, k)); - Tensor S_other_fp32 = convert_type(rS_other); - Tensor cos_fp32 = convert_type(rCos(_, m, k)); - Tensor sin_fp32 = convert_type(rSin(_, m, k)); - #pragma unroll - for (int i = 0; i < size<0>(rS); ++i) { - S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i)); - } - // Idk but I need to copy for the convert_type to work - Tensor S_fp32_copy = make_fragment_like(S_fp32); - cute::copy(S_fp32, S_fp32_copy); - using T = typename Engine0::value_type; - Tensor S_og_type = convert_type(S_fp32_copy); - cute::copy(S_og_type, rS(_, m, k)); - // if (cute::thread0()) { print_tensor(rS(_, m, k)); } - } - cute::copy(rS(_, m, k), D(_, m, k)); - } else if (Clear_OOB_K) { - cute::clear(D(_, m, k)); - } - } - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/softmax.h b/aten/src/ATen/native/transformers/cuda/flash_attn/softmax.h deleted file mode 100644 index 9a9ae88b6cd..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/softmax.h +++ /dev/null @@ -1,186 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Tri Dao. 
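The removed rotary.h applies rotary position embeddings in two layouts: interleaved, which rotates adjacent element pairs, and contiguous, which pairs each element in the first half of the rotary span with its counterpart in the second half. A small NumPy sketch of both rotations on a 1-D vector (illustrative only; the kernel works on copy fragments and performs the fp16/bf16 to fp32 conversion explicitly):

```
# Sketch of the rotations performed by the removed rotary.h, in plain NumPy.
# "Interleaved" rotates adjacent pairs (x[2i], x[2i+1]); "contiguous" pairs
# x[i] with x[i + rotary_dim/2]. Only the first rotary_dim elements rotate;
# the rest pass through unchanged. Illustrative only.
import numpy as np

def rotary_interleaved(x, cos, sin, rotary_dim):
    out = x.astype(np.float32).copy()
    for i in range(rotary_dim // 2):
        a, b = out[2 * i], out[2 * i + 1]
        out[2 * i]     = a * cos[i] - b * sin[i]   # real part
        out[2 * i + 1] = a * sin[i] + b * cos[i]   # imaginary part
    return out

def rotary_contiguous(x, cos, sin, rotary_dim):
    out = x.astype(np.float32).copy()
    half = rotary_dim // 2
    left, right = out[:half].copy(), out[half:rotary_dim].copy()
    out[:half]           = left * cos[:half] - right * sin[:half]
    out[half:rotary_dim] = left * sin[:half] + right * cos[:half]
    return out

x = np.arange(8, dtype=np.float32)
theta = np.full(3, 0.1)                 # rotary_dim = 6 -> 3 angles
print(rotary_interleaved(x, np.cos(theta), np.sin(theta), rotary_dim=6))
print(rotary_contiguous(x, np.cos(theta), np.sin(theta), rotary_dim=6))
```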
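The softmax.h whose removal begins here implements the streaming softmax bookkeeping: a running per-row maximum and sum of exponentials, with the accumulated output rescaled whenever a new block raises the maximum, and the log-sum-exp produced at the end (the kernel additionally guards the fully masked case where the maximum stays at -inf). A plain-Python sketch of that update rule, illustrative rather than the kernel's exp2-based register code:

```
# Sketch of the running ("online") softmax update implemented by the removed
# softmax.h: keep a running max and sum of exponentials per row, rescaling the
# previously accumulated values whenever a new block raises the max.
# Illustrative only; the kernel does this blockwise in registers with exp2f.
import numpy as np

def online_softmax(score_blocks):
    row_max, row_sum, acc = -np.inf, 0.0, None
    for block in score_blocks:                       # one block of scores at a time
        block_max = max(row_max, np.max(block))
        scale = np.exp(row_max - block_max)          # rescale earlier partial results
        row_sum = row_sum * scale + np.sum(np.exp(block - block_max))
        probs = np.exp(block - block_max)
        acc = probs if acc is None else np.concatenate([acc * scale, probs])
        row_max = block_max
    return acc / row_sum, row_max + np.log(row_sum)  # probabilities, log-sum-exp

blocks = [np.array([0.1, 0.3]), np.array([2.0, -1.0])]
probs, lse = online_softmax(blocks)
print(np.allclose(probs, np.exp(np.concatenate(blocks) - lse)))   # True
```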
- ******************************************************************************/ - -#pragma once - -#include - -#include - -#include - -#include -#include - -namespace pytorch_flash { - -using namespace cute; - -#define UNFUSE_FMA -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__device__ __forceinline__ void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) { - static_assert(Layout0::rank == 2, "Only support 2D Tensor"); - static_assert(Layout1::rank == 1, "Only support 1D Tensor"); - CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); - #pragma unroll - for (int mi = 0; mi < size<0>(tensor); mi++) { - summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); - #pragma unroll - for (int ni = 1; ni < size<1>(tensor); ni++) { - summary(mi) = op(summary(mi), tensor(mi, ni)); - } - } -} - -template -__device__ __forceinline__ void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) { - CUTE_STATIC_ASSERT_V(size(dst) == size(src)); - #pragma unroll - for (int i = 0; i < size(dst); i++){ - dst(i) = Allreduce<4>::run(src(i), op); - } -} - -template -__device__ __forceinline__ void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) { - thread_reduce_(tensor, summary, op); - quad_allreduce_(summary, summary, op); -} - -template -__device__ __forceinline__ void reduce_max(Tensor const& tensor, Tensor &max){ - MaxOp max_op; - reduce_(tensor, max, max_op); -} - -template -__device__ __forceinline__ void reduce_sum(Tensor const& tensor, Tensor &sum){ - SumOp sum_op; - thread_reduce_(tensor, sum, sum_op); -} - -// Apply the exp to all the elements. -template -__forceinline__ __device__ void scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) { - static_assert(Layout0::rank == 2, "Only support 2D Tensor"); - static_assert(Layout1::rank == 1, "Only support 1D Tensor"); - CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); - #pragma unroll - for (int mi = 0; mi < size<0>(tensor); ++mi) { - // If max is -inf, then all elements must have been -inf (possibly due to masking). - // We don't want (-inf - (-inf)) since that would give NaN. - // If we don't have float around M_LOG2E the multiplication is done in fp64. - const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E)); - #pragma unroll - for (int ni = 0; ni < size<1>(tensor); ++ni) { - // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - // max * log_2(e)) This allows the compiler to use the ffma - // instruction instead of fadd and fmul separately. - #ifdef UNFUSE_FMA - tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled); - #else - tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); - #endif - } - } -} - -// Apply the exp to all the elements. -template -__forceinline__ __device__ void max_scale_exp2_sum(Tensor &tensor, Tensor &max, Tensor &sum, const float scale) { - static_assert(Layout0::rank == 2, "Only support 2D Tensor"); - static_assert(Layout1::rank == 1, "Only support 1D Tensor"); - CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); - #pragma unroll - for (int mi = 0; mi < size<0>(tensor); ++mi) { - MaxOp max_op; - max(mi) = zero_init ? 
tensor(mi, 0) : max_op(max(mi), tensor(mi, 0)); - #pragma unroll - for (int ni = 1; ni < size<1>(tensor); ni++) { - max(mi) = max_op(max(mi), tensor(mi, ni)); - } - max(mi) = Allreduce<4>::run(max(mi), max_op); - // If max is -inf, then all elements must have been -inf (possibly due to masking). - // We don't want (-inf - (-inf)) since that would give NaN. - const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale; - sum(mi) = 0; - #pragma unroll - for (int ni = 0; ni < size<1>(tensor); ++ni) { - // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - // max * log_2(e)) This allows the compiler to use the ffma - // instruction instead of fadd and fmul separately. - tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); - sum(mi) += tensor(mi, ni); - } - SumOp sum_op; - sum(mi) = Allreduce<4>::run(sum(mi), sum_op); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Softmax { - - using TensorT = decltype(make_tensor(Shape>{})); - TensorT row_max, row_sum; - - __forceinline__ __device__ Softmax() {}; - - template - __forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) { - // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) - Tensor scores = make_tensor(acc_s.data(), pytorch_flash::convert_layout_acc_rowcol(acc_s.layout())); - static_assert(decltype(size<0>(scores))::value == kNRows); - if (Is_first) { - pytorch_flash::template reduce_max(scores, row_max); - pytorch_flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); - pytorch_flash::reduce_sum(scores, row_sum); - } else { - Tensor scores_max_prev = make_fragment_like(row_max); - cute::copy(row_max, scores_max_prev); - pytorch_flash::template reduce_max(scores, row_max); - // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) - Tensor acc_o_rowcol = make_tensor(acc_o.data(), pytorch_flash::convert_layout_acc_rowcol(acc_o.layout())); - static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); - #pragma unroll - for (int mi = 0; mi < size(row_max); ++mi) { - float scores_max_cur = !Check_inf - ? row_max(mi) - : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi)); - float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); - row_sum(mi) *= scores_scale; - #pragma unroll - for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; } - } - pytorch_flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); - // We don't do the reduce across threads here since we don't need to use the row_sum. - // We do that reduce at the end when we need to normalize the softmax. - pytorch_flash::reduce_sum(scores, row_sum); - } - }; - - template - __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) { - SumOp sum_op; - quad_allreduce_(row_sum, row_sum, sum_op); - TensorT lse = make_fragment_like(row_sum); - Tensor acc_o_rowcol = make_tensor(acc_o.data(), pytorch_flash::convert_layout_acc_rowcol(acc_o.layout())); - static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); - #pragma unroll - for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { - float sum = row_sum(mi); - float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; - lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum); - float scale = !Is_dropout ? 
inv_sum : inv_sum * rp_dropout; - #pragma unroll - for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } - } - return lse; - }; -}; - -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/utils.h b/aten/src/ATen/native/transformers/cuda/flash_attn/utils.h deleted file mode 100644 index 2c8add31836..00000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/utils.h +++ /dev/null @@ -1,394 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include -#include -#include - -#include - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -#include -#endif - -#include -#include - -#include -#include -#include -#include - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace pytorch_flash { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__forceinline__ __device__ uint32_t relu2(const uint32_t x); - -template<> -__forceinline__ __device__ uint32_t relu2(const uint32_t x) { - uint32_t res; - const uint32_t zero = 0u; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); -#else - asm volatile( \ - "{\n" \ - "\t .reg .f16x2 sela;\n" \ - "\t set.gtu.u32.f16x2 sela, %1, %2;\n" \ - "\t and.b32 %0, sela, %1;\n" - "}\n" : "=r"(res) : "r"(x), "r"(zero)); -#endif - return res; -} - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -template<> -__forceinline__ __device__ uint32_t relu2(const uint32_t x) { - uint32_t res; - const uint32_t zero = 0u; - asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); - return res; -} -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - -template -__forceinline__ __device__ uint32_t convert_relu2(const float2 x); - -template<> -__forceinline__ __device__ uint32_t convert_relu2(const float2 x) { - uint32_t res; - const uint32_t a = reinterpret_cast(x.x); - const uint32_t b = reinterpret_cast(x.y); - asm volatile("cvt.rn.relu.f16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a)); - return res; -} - -template<> -__forceinline__ __device__ uint32_t convert_relu2(const float2 x) { - uint32_t res; - const uint32_t a = reinterpret_cast(x.x); - const uint32_t b = reinterpret_cast(x.y); - asm volatile("cvt.rn.relu.bf16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a)); - return res; -} - -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MaxOp { -__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? 
x : y; } -}; - -template <> -struct MaxOp { -// This is slightly faster -__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct SumOp { -__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Allreduce { - static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); - template - static __device__ __forceinline__ T run(T x, Operator &op) { - constexpr int OFFSET = THREADS / 2; - x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); - return Allreduce::run(x, op); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -struct Allreduce<2> { -template -static __device__ __forceinline__ T run(T x, Operator &op) { - x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); - return x; -} -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__forceinline__ __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, - Tensor4 const& tCsB, TiledMma tiled_mma, - TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B, - ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) { - CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M - CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N - CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K - Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); - CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M - Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); - CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N - if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); } - if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); } - #pragma unroll - for (int i = 0; i < size<2>(tCrA); ++i) { - if (i < size<2>(tCrA) - 1) { - if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); } - if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); } - } - cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__forceinline__ __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB, - TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, - ThrCopy smem_thr_copy_B) { - CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M - CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N - CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K - Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); - CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N - cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); - #pragma unroll - for (int i = 0; i < size<2>(tCrA); ++i) { - if (i < size<2>(tCrA) - 1) { - cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); - } - cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// 
Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
-template<typename Layout>
-__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
-    static_assert(decltype(size<0>(acc_layout))::value == 4);
-    static_assert(decltype(rank(acc_layout))::value == 3);
-    auto l = logical_divide(acc_layout, Shape<_2>{});  // ((2, 2), MMA_M, MMA_N)
-    return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
-};
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
-// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8.
-template<typename MMA_traits, typename Layout>
-__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
-    using X = Underscore;
-    static_assert(decltype(size<0>(acc_layout))::value == 4);
-    static_assert(decltype(rank(acc_layout))::value == 3);
-    constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
-    static_assert(mma_shape_K == 8 || mma_shape_K == 16);
-    if constexpr (mma_shape_K == 8) {
-        return acc_layout;
-    } else {
-        auto l = logical_divide(acc_layout, Shape<X, X, _2>{});  // (4, MMA_M, (2, MMA_N / 2)))
-        return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
-    }
-};
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
-template<typename Layout>
-__forceinline__ __device__ auto convert_layout_acc_dropout(Layout acc_layout) {
-    using X = Underscore;
-    static_assert(decltype(size<0>(acc_layout))::value == 4);
-    static_assert(decltype(rank(acc_layout))::value == 3);
-    auto l = logical_divide(acc_layout, Shape<X, X, _2>{});  // (4, MMA_M, (2, MMA_N / 2)))
-    return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
-};
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template <typename To_type, typename Engine, typename Layout>
-__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
-    using From_type = typename Engine::value_type;
-    constexpr int numel = decltype(size(tensor))::value;
-    cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
-    // HACK: this requires tensor to be "contiguous"
-    auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
-    return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template <typename Engine, typename Layout>
-__forceinline__ __device__ void relu_(Tensor<Engine, Layout> &tensor) {
-    constexpr int numel = decltype(size(tensor))::value;
-    static_assert(numel % 2 == 0);
-    using value_t = typename Engine::value_type;
-    // HACK: this requires tensor to be "contiguous"
-    Tensor tensor_uint32 = recast<uint32_t>(tensor);
-    #pragma unroll
-    for (int i = 0; i < size(tensor_uint32); ++i) {
-        tensor_uint32(i) = relu2<value_t>(tensor_uint32(i));
-    }
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction
-template <typename To_type, typename Engine, typename Layout>
-__forceinline__ __device__ auto convert_type_relu(Tensor<Engine, Layout> const &tensor) {
-    using From_type = typename Engine::value_type;
-    static_assert(std::is_same_v<To_type, cutlass::half_t> || std::is_same_v<To_type, cutlass::bfloat16_t>);
-    static_assert(std::is_same_v<float, From_type>);
-    constexpr int numel = decltype(size(tensor))::value;
-    static_assert(numel % 2 == 0);
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
-    // HACK: this requires tensor to be "contiguous"
-    Tensor tensor_float2 = recast<float2>(tensor);
-    Tensor out_uint32 = make_tensor<uint32_t>(tensor_float2.layout());
-    #pragma unroll
-    for (int i = 0; i < size(out_uint32); ++i) {
-        out_uint32(i) = convert_relu2<To_type>(tensor_float2(i));
-    }
-    Tensor out = make_tensor(make_rmem_ptr<To_type>(out_uint32.data()), tensor.layout());
-#else
-    Tensor out = pytorch_flash::convert_type<To_type>(tensor);
-    pytorch_flash::relu_(out);
-#endif
-    return out;
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-// Blocks until all but N previous cp.async.commit_group operations have committed.
-// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all
-// (which is equivalent to commit_group then wait_group 0).
-// Instead we just call cp.async.wait_group 0, which is slightly faster.
-// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113
-template <int N>
-CUTE_HOST_DEVICE
-void cp_async_wait() {
-#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
-    asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
-#endif
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
-          typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
-          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
-__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
-                                     Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
-                                     Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
-    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
-    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
-    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));  // MMA
-    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));  // MMA_M
-    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));  // MMA_K
-    // There's no case where !Clear_OOB_K && Clear_OOB_MN
-    static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
-    #pragma unroll
-    for (int m = 0; m < size<1>(S); ++m) {
-        if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
-            #pragma unroll
-            for (int k = 0; k < size<2>(S); ++k) {
-                if (Is_even_K || predicate_K(k)) {
-                    cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
-                } else if (Clear_OOB_K) {
-                    cute::clear(D(_, m, k));
-                }
-            }
-        } else if (Clear_OOB_MN) {
-            cute::clear(D(_, m, _));
-        }
-    }
-    // TD [2023-04-13]: Strange that the code below can cause race condition.
-    // I think it's because the copies are under an if statement.
-    // if (Is_even_K) {
-    //     #pragma unroll
-    //     for (int m = 0; m < size<1>(S); ++m) {
-    //         if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
-    //             copy(tiled_copy, S(_, m, _), D(_, m, _));
-    //         } else if (Clear_OOB_MN) {
-    //             clear(D(_, m, _));
-    //         }
-    //     }
-    // } else {  // It's slightly faster in this case if iterate over K first
-    //     #pragma unroll
-    //     for (int k = 0; k < size<2>(S); ++k) {
-    //         if (predicate_K(k)) {
-    //             #pragma unroll
-    //             for (int m = 0; m < size<1>(S); ++m) {
-    //                 if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
-    //                     copy(tiled_copy, S(_, m, k), D(_, m, k));
-    //                 } else if (Clear_OOB_MN) {
-    //                     clear(D(_, m, k));
-    //                 }
-    //             }
-    //         } else if (Clear_OOB_K) {  // There's no case where !Clear_OOB_K && Clear_OOB_MN
-    //             if (Clear_OOB_MN || Is_even_MN) {
-    //                 clear(D(_, _, k));
-    //             } else {
-    //                 #pragma unroll
-    //                 for (int m = 0; m < size<1>(S); ++m) {
-    //                     if (!(Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN)) {
-    //                         clear(D(_, m, k));
-    //                     }
-    //                 }
-    //             }
-    //         }
-    //     }
-    // }
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template <bool Is_even_K=true,
-          typename Engine0, typename Layout0, typename Engine1, typename Layout1,
-          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
-__forceinline__ __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,
-                                               Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
-                                               Tensor<Engine3, Layout3> const &predicate_K,
-                                               const int max_MN=0, const int min_MN=0) {
-    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
-    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
-    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));  // MMA
-    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));  // MMA_M
-    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));  // MMA_K
-    // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); }
-    #pragma unroll
-    for (int m = 0; m < size<1>(S); ++m) {
-        // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
-        if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
-            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
-            #pragma unroll
-            for (int k = 0; k < size<2>(S); ++k) {
-                if (Is_even_K || predicate_K(k)) {
-                    cute::copy(S(_, m, k), D(_, m, k));
-                }
-            }
-        }
-    }
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-}  // namespace pytorch_flash
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h
index 10bfe248a76..fcef4200134 100644
--- a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h
+++ b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h
@@ -267,6 +267,7 @@ mha_fwd(
     bool is_causal,
     int window_size_left,
     int window_size_right,
+    const float softcap,
     const bool return_softmax,
     std::optional<at::Generator> gen_) {
 #if defined(USE_CK_FLASH_ATTENTION)
@@ -351,6 +352,7 @@ mha_varlen_fwd(
     bool is_causal,
     int window_size_left,
     int window_size_right,
+    const float softcap,
     const bool return_softmax,
     std::optional<at::Generator> gen_) {
 #if defined(USE_CK_FLASH_ATTENTION)
@@ -441,6 +443,7 @@ inline std::tuple mha_bwd(
     const bool is_causal,
     int window_size_left,
     int window_size_right,
+    const float softcap,
     const bool deterministic,
     const at::Tensor philox_seed,
     const at::Tensor philox_offset) {
@@ -541,6 +544,7 @@ inline std::tuple mha_varlen_bwd
     const bool is_causal,
     int window_size_left,
     int window_size_right,
+    const float softcap,
     const bool deterministic,
     const at::Tensor philox_seed,
     const at::Tensor philox_offset) {
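The HIP/CK entry points above gain a `softcap` argument so their signatures stay aligned with the CUDA path and the upstream flash-attention API, even though the CUDA build compiles the feature out via FLASHATTENTION_DISABLE_SOFTCAP in the CMake hunk that follows. Assuming `softcap` follows the upstream convention of tanh-based logit capping (an assumption, not something this patch states), a minimal unfused reference sketch looks like this:

```python
# Illustrative only: an unfused reference for tanh logit soft-capping, assuming the
# new `softcap` parameter means scores -> softcap * tanh(scores / softcap) before
# softmax, as in upstream flash-attention. Shapes are arbitrary.
import math

import torch


def sdpa_with_softcap(q, k, v, softcap=0.0, is_causal=False):
    # q, k, v: (batch, heads, seq, head_dim)
    scale = 1.0 / math.sqrt(q.size(-1))
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    if softcap > 0.0:
        # Bound the attention logits smoothly to (-softcap, softcap).
        scores = softcap * torch.tanh(scores / softcap)
    if is_causal:
        seq_q, seq_k = scores.shape[-2:]
        causal = torch.ones(seq_q, seq_k, dtype=torch.bool, device=q.device).tril()
        scores = scores.masked_fill(~causal, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v


q = k = v = torch.randn(2, 4, 16, 32)
print(sdpa_with_softcap(q, k, v, softcap=30.0, is_causal=True).shape)
```

Passing `softcap=0.0` in this sketch reduces to plain scaled dot-product attention, which matches the behavior expected while the fused softcap path stays disabled.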
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index 44b1d0213ee..54e486917f6 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -980,7 +980,16 @@ elseif(USE_CUDA)
     target_compile_definitions(torch_cuda PRIVATE USE_UCC)
   endif()
   if(USE_FLASH_ATTENTION)
-    target_compile_definitions(torch_cuda PRIVATE USE_FLASH_ATTENTION)
+    target_compile_definitions(torch_cuda PRIVATE
+      USE_FLASH_ATTENTION
+      FLASHATTENTION_DISABLE_ALIBI # Disable alibi attention as it's not currently used
+      FLASHATTENTION_DISABLE_SOFTCAP
+      FLASH_NAMESPACE=pytorch_flash
+      UNFUSE_FMA # Addressing issue #121558
+    )
+    target_include_directories(torch_cuda PRIVATE
+      ${PROJECT_SOURCE_DIR}/third_party/flash-attention/csrc/flash_attn/src/
+    )
   endif()
   if(USE_MEM_EFF_ATTENTION)
     target_compile_definitions(torch_cuda PRIVATE USE_MEM_EFF_ATTENTION)
@@ -1372,7 +1381,13 @@ if(USE_ROCM)
       ${ROCM_SOURCE_DIR}/include/rccl/
     )
   if(USE_FLASH_ATTENTION)
-    target_compile_definitions(torch_hip PRIVATE USE_FLASH_ATTENTION)
+    target_compile_definitions(torch_hip PRIVATE
+      USE_FLASH_ATTENTION
+      FLASHATTENTION_DISABLE_ALIBI # Disable alibi attention as it's not currently used
+      FLASHATTENTION_DISABLE_SOFTCAP
+      FLASH_NAMESPACE=pytorch_flash
+      UNFUSE_FMA # Addressing issue #121558
+    )
   endif()
   if(USE_MEM_EFF_ATTENTION)
     target_compile_definitions(torch_hip PRIVATE USE_MEM_EFF_ATTENTION)
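The compile definitions above gate the vendored build: the alibi and softcap code paths are compiled out, FLASH_NAMESPACE=pytorch_flash keeps the submodule's symbols out of the default flash namespace, and UNFUSE_FMA addresses issue #121558. A quick runtime sanity check is sketched below; it uses existing public PyTorch APIs (torch.backends.cuda.flash_sdp_enabled, torch.nn.attention.sdpa_kernel) rather than anything added by this patch, and the tensor sizes and dtypes are arbitrary.

```python
# Illustrative only: confirm that the flash SDP backend built from the submodule
# is available and actually dispatchable on this build.
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

print("flash SDP enabled:", torch.backends.cuda.flash_sdp_enabled())

if torch.cuda.is_available():
    q, k, v = (torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16) for _ in range(3))
    # Restrict dispatch to flash attention; the call below raises if the flash
    # kernels cannot handle these inputs (e.g. the build has them disabled).
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
    print(out.shape)
```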
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index fa77b906b1b..842870eed7c 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -2873,17 +2873,17 @@
   output_differentiability: [True, False, False, False]
   query, key, value, attn_bias: _scaled_dot_product_efficient_attention_backward(grad, query, key, value, attn_bias, output, log_sumexp, philox_seed, philox_offset, dropout_p, grad_input_mask, is_causal, scale)

-- name: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+- name: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor rng_state, Tensor unused, Tensor debug_attn_mask)
   output_differentiability: [True, False, False, False, False, False, False, False, False]
-  query, key, value: _scaled_dot_product_flash_attention_backward_symint(grad, query, key, value, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale)
+  query, key, value: _scaled_dot_product_flash_attention_backward_symint(grad, query, key, value, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, rng_state, unused, scale)

 - name: _scaled_dot_product_flash_attention_for_cpu(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, *, Tensor? attn_mask=None, float? scale=None) -> (Tensor output, Tensor logsumexp)
   output_differentiability: [True, False]
   query, key, value: _scaled_dot_product_flash_attention_for_cpu_backward(grad, query, key, value, output, logsumexp, dropout_p, is_causal, attn_mask, scale)

-- name: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+- name: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor rng_state, Tensor unused, Tensor debug_attn_mask)
   output_differentiability: [True, False, False, False, False]
-  query, key, value: _flash_attention_backward_symint(grad, query, key, value, output, softmax_logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale, window_size_left, window_size_right)
+  query, key, value: _flash_attention_backward_symint(grad, query, key, value, output, softmax_logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, rng_state, unused, scale, window_size_left, window_size_right)

 - name: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt? max_seqlen_q, SymInt? max_seqlen_k, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? seqlen_k=None, int? window_size=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k)
   output_differentiability: [True, False, False, False, False, False]
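The renames above are the BC surface called out in the summary: the 7th and 8th outputs of `_scaled_dot_product_flash_attention` and `_flash_attention_forward` are now `rng_state` and an `unused` placeholder rather than `philox_seed`/`philox_offset`, and the backward formulas consume them under the new names. A hedged sketch of what that looks like from Python follows; shapes and dtypes are taken from this patch's meta registrations for the CUDA path, and the input sizes are arbitrary.

```python
# Illustrative only: the op signature is unchanged, only the documented names and
# the shape/dtype of the RNG outputs move. Requires a CUDA build with flash
# attention enabled.
import torch

if torch.cuda.is_available():
    q, k, v = (torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16) for _ in range(3))
    (out, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k,
     rng_state, unused, debug_mask) = torch.ops.aten._scaled_dot_product_flash_attention(
        q, k, v, dropout_p=0.1, is_causal=True)
    # CUDA build: rng_state packs seed and offset into one shape-(2,) uint64 tensor;
    # per this patch, ROCm keeps the old pair of scalar long tensors.
    print(rng_state.shape, rng_state.dtype)
```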
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 7aec860ef21..1c866d1ab6e 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -5545,6 +5545,14 @@ def meta__scaled_dot_product_flash_attention(
     # capturing or not, but at the time of tracing we don't know if we
     # are going to use cudagraphs or not, so we return meta tensors here
     # it's possible we'll need to have some special handling in inductor for sdpa
+    # See [Note] BC breaking change to flash seed/offset
+    if torch.version.hip and torch.cuda.is_available():
+        # Maintain old path on AMD
+        seed = torch.empty((), dtype=torch.long, device="meta")
+        offset = torch.empty((), dtype=torch.long, device="meta")
+    else:
+        seed = torch.empty((2), dtype=torch.uint64, device="meta")
+        offset = torch.empty((), dtype=torch.uint64, device="meta")

     return (
         attention,
@@ -5553,8 +5561,8 @@
         None,
         max_seqlen_batch_q,
         max_seqlen_batch_k,
-        torch.empty((), dtype=torch.long, device="meta"),
-        torch.empty((), dtype=torch.long, device="meta"),
+        seed,
+        offset,
         debug_mask,
     )

@@ -5878,11 +5886,17 @@ def meta__flash_attention_forward(

     # Cuda Path
     attention = torch.empty_like(query)
-    logsumexp = torch.empty(
-        (batch_size, num_heads, max_seqlen_batch_q),
-        dtype=torch.float,
-        device=query.device,
-    )
+    if cum_seq_q is None:
+        logsumexp = torch.empty(
+            (batch_size, num_heads, max_seqlen_batch_q),
+            dtype=torch.float,
+            device=query.device,
+        )
+    else:
+        total_q = query.size(0)
+        logsumexp = torch.empty(
+            (num_heads, total_q), dtype=torch.float, device=query.device
+        )

     if return_debug_mask:
         blocksize_c = 128 if head_dim > 64 else 256
@@ -5899,12 +5913,21 @@ def meta__flash_attention_forward(
     else:
         debug_mask = torch.empty(0, dtype=query.dtype, device=query.device)

-    # See Note [Seed and Offset]:
+    # See Note [Seed and Offset]
+    # See [Note] BC breaking change to flash seed/offset
+    seed, offset = None, None
+    if torch.version.hip and torch.cuda.is_available():
+        # Maintain old path on AMD
+        seed = torch.empty((), dtype=torch.long, device="meta")
+        offset = torch.empty((), dtype=torch.long, device="meta")
+    else:
+        seed = torch.empty((2), dtype=torch.uint64, device="meta")
+        offset = torch.empty((), dtype=torch.uint64, device="meta")
     return (
         attention,
         logsumexp,
-        torch.empty((), dtype=torch.long, device="meta"),
-        torch.empty((), dtype=torch.long, device="meta"),
+        seed,
+        offset,
         debug_mask,
     )
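Besides the seed/offset change, the meta function for `_flash_attention_forward` now distinguishes the dense and varlen layouts of `logsumexp`. The stand-alone sketch below mirrors that `cum_seq_q is None` branch; `fake_logsumexp_shape` is a hypothetical helper for illustration and the sizes are made up.

```python
# Illustrative only: dense path returns (batch, heads, max_seqlen_q); varlen path
# packs sequences along one dimension and returns (heads, total_q).
import torch


def fake_logsumexp_shape(batch, heads, max_q, total_q, cum_seq_q=None):
    if cum_seq_q is None:
        # Dense path: one logsumexp value per (batch, head, query position).
        return torch.empty((batch, heads, max_q), dtype=torch.float, device="meta").shape
    # Varlen path: sequences are packed, so logsumexp is (heads, total_q).
    return torch.empty((heads, total_q), dtype=torch.float, device="meta").shape


print(fake_logsumexp_shape(2, 8, 128, 256))  # torch.Size([2, 8, 128])
print(fake_logsumexp_shape(2, 8, 128, 256, cum_seq_q=torch.tensor([0, 128, 256])))  # torch.Size([8, 256])
```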
diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h
index 8bf3cd03c7d..95de9ac9d07 100644
--- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h
+++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h
@@ -28,7 +28,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__embedding_bag_forward_only(Ate
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__embedding_bag_per_sample_weights_backward(AtenTensorHandle grad, AtenTensorHandle weight, AtenTensorHandle indices, AtenTensorHandle offsets, AtenTensorHandle offset2bag, int64_t mode, int64_t padding_idx, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fft_c2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t forward, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fft_r2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t onesided, AtenTensorHandle* ret0);
-AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__flash_attention_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, int64_t* window_size_left, int64_t* window_size_right, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__flash_attention_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle rng_state, AtenTensorHandle unused, double* scale, int64_t* window_size_left, int64_t* window_size_right, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__flash_attention_forward(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* cum_seq_q, AtenTensorHandle* cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, int64_t* window_size_left, int64_t* window_size_right, AtenTensorHandle* seqused_k, AtenTensorHandle* alibi_slopes, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fused_moving_avg_obs_fq_helper(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5);
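The regenerated C shim keeps the same ABI and only renames the two RNG arguments, so AOTInductor-generated calls thread the forward's `rng_state`/`unused` outputs into the backward exactly as the derivatives.yaml formula above does. An illustrative end-to-end check through autograd is sketched below; the sizes are arbitrary and it requires a CUDA build with flash attention enabled.

```python
# Illustrative only: autograd saves the renamed rng_state / unused outputs from the
# flash forward and feeds them to the flash backward under the new names.
import torch

if torch.cuda.is_available():
    q, k, v = (
        torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16, requires_grad=True)
        for _ in range(3)
    )
    outs = torch.ops.aten._scaled_dot_product_flash_attention(q, k, v, dropout_p=0.1, is_causal=True)
    # Outputs 7 and 8 (0-based indices 6 and 7) are rng_state and the unused placeholder.
    out, rng_state, unused = outs[0], outs[6], outs[7]
    out.backward(torch.randn_like(out))
    print(q.grad.shape, rng_state.shape)
```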