Update flash_attention kernel from 2.3.6 to 2.5.5 (#118935)

# Summary
Updates FlashAttention kernel code from tag [2.3.6](https://github.com/Dao-AILab/flash-attention/releases/tag/v2.3.6) to [2.5.5](https://github.com/Dao-AILab/flash-attention/releases/tag/v2.5.5).

The usual set of local changes was then re-applied on top of the updated kernel: changing how the dropout state is saved for backward, and removing the head_dim_pad step, since padding would make the kernel mutate its input in place, which interacts badly with functionalization.
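For context, the deterministic backward path introduced by this update is driven by PyTorch's global determinism setting rather than a new user-facing kernel argument. Below is a minimal sketch of how a user would opt in, assuming a CUDA build with the flash SDPA backend available; the shapes and the `sdp_kernel` backend selection here are illustrative, not prescribed by this PR.

```python
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16,
                       requires_grad=True) for _ in range(3))

# With warn_only=True the flash backward only warns that it defaults to a
# non-deterministic algorithm; with warn_only=False it takes the (slower)
# deterministic dq-accumulation path added by this update.
torch.use_deterministic_algorithms(True, warn_only=False)

with torch.backends.cuda.sdp_kernel(enable_flash=True,
                                    enable_math=False,
                                    enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.1, is_causal=True)
out.sum().backward()
```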

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118935
Approved by: https://github.com/cpuhrsch
Authored by drisspg on 2024-03-04 17:36:22 +00:00; committed by PyTorch MergeBot
commit 2e6c08a14b (parent d49864f6a5)
42 changed files with 2117 additions and 2306 deletions

View File

@ -744,6 +744,12 @@ cmake_dependent_option(
Will be disabled if not supported by the platform" ON Will be disabled if not supported by the platform" ON
"USE_CUDA AND NOT MSVC" OFF) "USE_CUDA AND NOT MSVC" OFF)
# We are currently not using alibi attention for Flash,
# so we disable this feature by default.
# We don't currently document this feature because we
# don't expect users building from source to need it.
add_definitions(-DFLASHATTENTION_DISABLE_ALIBI)
# CAVEAT: Again, do not check USE_ROCM here # CAVEAT: Again, do not check USE_ROCM here
# Flash Attention2 will error while building for sm52 while Mem Eff Attention won't # Flash Attention2 will error while building for sm52 while Mem Eff Attention won't
cmake_dependent_option( cmake_dependent_option(

View File

@ -50,7 +50,6 @@
#include <ATen/ops/scalar_tensor.h> #include <ATen/ops/scalar_tensor.h>
#include <ATen/ops/scaled_dot_product_attention.h> #include <ATen/ops/scaled_dot_product_attention.h>
#include <ATen/ops/split_native.h> #include <ATen/ops/split_native.h>
#include <ATen/ops/narrow_native.h>
#include <ATen/ops/zeros.h> #include <ATen/ops/zeros.h>
#endif #endif
@ -65,7 +64,6 @@
#include <ATen/native/transformers/attention.h> #include <ATen/native/transformers/attention.h>
#include <ATen/native/nested/NestedTensorUtils.h> #include <ATen/native/nested/NestedTensorUtils.h>
#include <ATen/native/nested/NestedTensorTransformerFunctions.h> #include <ATen/native/nested/NestedTensorTransformerFunctions.h>
#include <ATen/native/nested/NestedTensorUtils.h>
#include <ATen/native/transformers/cuda/sdp_utils.h> #include <ATen/native/transformers/cuda/sdp_utils.h>
#include <ATen/native/transformers/sdp_utils_cpp.h> #include <ATen/native/transformers/sdp_utils_cpp.h>
@ -852,6 +850,7 @@ _flash_attention_forward(
// of the tensor. This is useful for kv cache scenarios but for now // of the tensor. This is useful for kv cache scenarios but for now
// we will not support in this PR. // we will not support in this PR.
c10::optional<Tensor> seqused_k = c10::nullopt; c10::optional<Tensor> seqused_k = c10::nullopt;
c10::optional<Tensor> alibi_slopes = c10::nullopt;
// We are going to have two paths: // We are going to have two paths:
// 1. The standard MHA path for dense tensors // 1. The standard MHA path for dense tensors
@ -880,6 +879,7 @@ _flash_attention_forward(
cumulative_sequence_length_q.value(), cumulative_sequence_length_q.value(),
cumulative_sequence_length_k.value(), cumulative_sequence_length_k.value(),
seqused_k, /*seqused_k*/ seqused_k, /*seqused_k*/
alibi_slopes, /*alibi_slopes*/
max_seqlen_batch_q, max_seqlen_batch_q,
max_seqlen_batch_k, max_seqlen_batch_k,
dropout_p, dropout_p,
@ -905,6 +905,7 @@ _flash_attention_forward(
key, key,
value, value,
out, out,
alibi_slopes,
dropout_p, dropout_p,
softmax_scale, softmax_scale,
is_causal, is_causal,

View File

@ -1,3 +1,4 @@
#include <string_view>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <cstdint> #include <cstdint>
#include <type_traits> #include <type_traits>
@ -41,9 +42,8 @@
#include <ATen/native/transformers/cuda/mem_eff_attention/gemm_kernel_utils.h> #include <ATen/native/transformers/cuda/mem_eff_attention/gemm_kernel_utils.h>
#include <ATen/native/transformers/cuda/mem_eff_attention/pytorch_utils.h> #include <ATen/native/transformers/cuda/mem_eff_attention/pytorch_utils.h>
#endif #endif
namespace at {
namespace native { namespace at::native {
std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward( std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
const Tensor& grad_out, const Tensor& grad_out,
@ -74,6 +74,21 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
// The kernel computes irregardless we will drop for this functions return // The kernel computes irregardless we will drop for this functions return
Tensor grad_softmax; Tensor grad_softmax;
// Currently unused args:
c10::optional<at::Tensor> alibi_slopes{c10::nullopt};
bool determinisitic{false};
auto& ctx = at::globalContext();
if (ctx.deterministicAlgorithms()) {
if (ctx.deterministicAlgorithmsWarnOnly()) {
TORCH_WARN_ONCE(
"Flash Attention defaults to a non-deterministic algorithm. ",
"To explicitly enable determinism call torch.use_deterministic_algorithms(True, warn_only=False).");
} else {
determinisitic = true;
}
}
// We check the whether the cumulative_sequence_length_q is defined // We check the whether the cumulative_sequence_length_q is defined
// in order to determine whether we are using varlen or dense forward // in order to determine whether we are using varlen or dense forward
if (cumulative_sequence_length_q.defined()) { if (cumulative_sequence_length_q.defined()) {
@ -90,6 +105,7 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
dv, dv,
cumulative_sequence_length_q, cumulative_sequence_length_q,
cumulative_sequence_length_k, cumulative_sequence_length_k,
alibi_slopes,
max_seqlen_batch_q, max_seqlen_batch_q,
max_seqlen_batch_k, max_seqlen_batch_k,
dropout_p, dropout_p,
@ -98,6 +114,7 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
is_causal, is_causal,
-1, /*window_size_left*/ -1, /*window_size_left*/
-1, /*window_size_right*/ -1, /*window_size_right*/
determinisitic,
philox_seed, philox_seed,
philox_offset); philox_offset);
return std::make_tuple(dQuery, dKey, dValue); return std::make_tuple(dQuery, dKey, dValue);
@ -113,11 +130,13 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
dq, dq,
dk, dk,
dv, dv,
alibi_slopes,
dropout_p, dropout_p,
softmax_scale, softmax_scale,
is_causal, is_causal,
-1, /*window_size_left*/ -1, /*window_size_left*/
-1, /*window_size_right*/ -1, /*window_size_right*/
determinisitic,
philox_seed, philox_seed,
philox_offset); philox_offset);
return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue)); return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue));
@ -630,5 +649,4 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_e
grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2), grad_bias); grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2), grad_bias);
} }
} // namespace native } // namespace at::native
} // namespace at

View File

@ -0,0 +1,74 @@
#include <cmath>
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <ATen/native/transformers/cuda/flash_attn/utils.h>
namespace pytorch_flash {
using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
template <bool Is_causal>
struct Alibi {
const float alibi_slope;
const int max_seqlen_k, max_seqlen_q;
__forceinline__ __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q)
: alibi_slope(alibi_slope)
, max_seqlen_k(max_seqlen_k)
, max_seqlen_q(max_seqlen_q) {
};
template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
const int col_idx_offset_,
const int row_idx_offset,
const int warp_row_stride) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
const int lane_id = threadIdx.x % 32;
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
}
}
}
} else { // Bias depends on both row_idx and col_idx
#pragma unroll
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
const int row_idx_base = row_idx_offset + mi * warp_row_stride;
#pragma unroll
for (int i = 0; i < size<0, 0>(tensor); ++i) {
const int row_idx = row_idx_base + i * 8;
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
}
}
}
}
}
}
};
} // namespace pytorch_flash
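For reference, a small host-side sketch (Python, hypothetical shapes) of the bias this kernel applies: in the general case each score at (row, col) receives -slope * |row + seqlen_k - seqlen_q - col|, while the causal branch simply adds slope * col to every row, which differs from the general formula only by a per-row constant that softmax cancels.

```python
import torch

def alibi_bias(slope: float, seqlen_q: int, seqlen_k: int) -> torch.Tensor:
    # Mirrors the non-causal branch of Alibi::apply_alibi:
    # bias[row, col] = -slope * |row + seqlen_k - seqlen_q - col|
    row = torch.arange(seqlen_q).unsqueeze(1)
    col = torch.arange(seqlen_k).unsqueeze(0)
    return -slope * (row + seqlen_k - seqlen_q - col).abs()

# The causal branch instead adds slope * col to every row; per row this
# differs from the formula above only by a constant, and softmax is
# invariant to per-row constant shifts.
bias = alibi_bias(0.25, seqlen_q=4, seqlen_k=6)
print(bias)
```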

View File

@ -24,12 +24,12 @@ struct BlockInfo {
} }
template <typename index_t> template <typename index_t>
inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
} }
template <typename index_t> template <typename index_t>
inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride; return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
} }

View File

@ -0,0 +1,96 @@
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#pragma once
#include <ATen/native/transformers/cuda/flash_attn/philox.cuh>
#include <ATen/native/transformers/cuda/flash_attn/utils.h>
namespace pytorch_flash {
using namespace cute;
struct Dropout {
const unsigned long long seed, offset;
const uint8_t p_dropout_in_uint8_t;
__forceinline__ __device__ Dropout(const unsigned long long seed, const unsigned long long offset,
const uint8_t p_dropout_in_uint8_t,
const int bid, const int hid, const int tid, const int nheads)
: seed(seed)
, offset(offset + (bid * nheads + hid) * 32 + tid % 32)
, p_dropout_in_uint8_t(p_dropout_in_uint8_t) {
}
template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
__forceinline__ __device__ void apply_dropout(Tensor<Engine, Layout> &tensor_,
int block_row_start, int block_col_start, int block_row_stride) {
// convert shape from (4, MMA_M, MMA_N) to (8, MMA_M, MMA_N / 2)
Tensor tensor = make_tensor(tensor_.data(), pytorch_flash::convert_layout_acc_dropout(tensor_.layout()));
using T = typename Engine::value_type;
auto encode_dropout = [](bool keep, T val) {
return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0));
};
static_assert(decltype(size<2>(tensor))::value % 2 == 0);
const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t);
const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t);
// if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); }
#pragma unroll
for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) {
uint2 rowcol = make_uint2(block_row_start, block_col_start);
#pragma unroll
for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) {
// if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));}
uint4 random_uint4 = pytorch_flash::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
// if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);}
uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4);
// Special implementation for 16-bit types: we duplicate the threshold to the
// low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction
// to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000,
// and the high 16 bits will be either 0xffff or 0x0000, depending on whether
// the random value is less than the threshold.
// We then do a bit-wise AND between the mask and the original value (in 32-bit).
// We're exploiting the fact that floating point comparison is equivalent to integer
// comparison, since we're comparing unsigned integers whose top 8-bits are zero.
if (!encode_dropout_in_sign_bit
&& (std::is_same<T, cutlass::half_t>::value || std::is_same<T, cutlass::bfloat16_t>::value)) {
uint16_t rnd_16[16];
#pragma unroll
for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); }
uint32_t (&rnd_32)[8] = reinterpret_cast<uint32_t (&)[8]>(rnd_16);
#pragma unroll
for (int j = 0; j < 2; j++) {
Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
// if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); }
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
#pragma unroll
for (int i = 0; i < 4; i++) {
uint32_t mask;
asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t));
tensor_uint32(i) &= mask;
}
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
}
} else {
#pragma unroll
for (int j = 0; j < 2; j++) {
#pragma unroll
for (int i = 0; i < 8; i++) {
tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j));
}
Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
}
}
// // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// // printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w);
// // }
}
}
}
};
} // namespace pytorch_flash
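As a rough reference for what apply_dropout decides per element, here is a host-side sketch in Python. The 8-bit threshold is assumed to be derived from the keep probability as floor(keep_prob * 255), matching upstream FlashAttention; the Philox counter layout and the f16x2 masking trick are deliberately not reproduced.

```python
import numpy as np

def dropout_reference(scores: np.ndarray, keep_prob: float, seed: int) -> np.ndarray:
    # Assumed encoding: keep an element iff its random byte <= floor(keep_prob * 255).
    threshold = np.uint8(np.floor(keep_prob * 255.0))
    rng = np.random.default_rng(seed)              # stand-in for the Philox generator
    rnd = rng.integers(0, 256, size=scores.shape, dtype=np.uint8)
    keep = rnd <= threshold
    # The kernel either zeroes dropped elements or (with encode_dropout_in_sign_bit)
    # flags them by negating; rescaling by 1/keep_prob happens later via rp_dropout.
    return np.where(keep, scores, np.float32(0.0))

out = dropout_reference(np.ones((2, 4), dtype=np.float32), keep_prob=0.9, seed=0)
```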

View File

@ -5,13 +5,15 @@
#pragma once #pragma once
#include <cuda.h> #include <cuda.h>
#include <vector>
#include <ATen/cuda/PhiloxUtils.cuh>
namespace pytorch_flash{
#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif
#include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
namespace pytorch_flash {
constexpr int TOTAL_DIM = 0; constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1; constexpr int H_DIM = 1;
constexpr int D_DIM = 2; constexpr int D_DIM = 2;
@ -19,7 +21,7 @@ constexpr int D_DIM = 2;
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
struct Qkv_params { struct Qkv_params {
using index_t = uint32_t; using index_t = int64_t;
// The QKV matrices. // The QKV matrices.
void *__restrict__ q_ptr; void *__restrict__ q_ptr;
void *__restrict__ k_ptr; void *__restrict__ k_ptr;
@ -96,7 +98,12 @@ struct Flash_fwd_params : public Qkv_params {
void * __restrict__ rotary_sin_ptr; void * __restrict__ rotary_sin_ptr;
// The indices to index into the KV cache. // The indices to index into the KV cache.
int *__restrict__ cache_batch_idx; int * __restrict__ cache_batch_idx;
// Paged KV cache
int * __restrict__ block_table;
index_t block_table_batch_stride;
int page_block_size;
// The dropout probability (probability of keeping an activation). // The dropout probability (probability of keeping an activation).
float p_dropout; float p_dropout;
@ -126,6 +133,9 @@ struct Flash_fwd_params : public Qkv_params {
bool is_rotary_interleaved; bool is_rotary_interleaved;
int num_splits; // For split-KV version int num_splits; // For split-KV version
void * __restrict__ alibi_slopes_ptr;
index_t alibi_slopes_batch_stride;
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
@ -165,6 +175,9 @@ struct Flash_bwd_params : public Flash_fwd_params {
// The pointer to the softmax d sum. // The pointer to the softmax d sum.
void *__restrict__ dsoftmax_sum; void *__restrict__ dsoftmax_sum;
bool deterministic;
index_t dq_accum_split_stride;
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
@ -172,7 +185,6 @@ struct Flash_bwd_params : public Flash_fwd_params {
template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream); template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream); template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream, const bool configure); template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
} // namespace pytorch_flash } // namespace pytorch_flash
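The new block_table / page_block_size fields describe a paged KV cache: each batch element's keys and values live in fixed-size blocks whose physical indices are listed in block_table. A minimal indexing sketch (Python, hypothetical tensor names) of how a logical token position maps into such a cache:

```python
import torch

def gather_paged_kv(kcache: torch.Tensor,       # [num_blocks, page_block_size, num_heads_k, head_size]
                    block_table: torch.Tensor,  # [batch_size, max_num_blocks_per_seq], int32
                    batch_idx: int,
                    token_idx: int) -> torch.Tensor:
    page_block_size = kcache.size(1)
    block = block_table[batch_idx, token_idx // page_block_size]  # physical block id
    return kcache[block, token_idx % page_block_size]             # [num_heads_k, head_size]

# Example: token 300 of batch element 0 with 256-token blocks lives in
# physical block block_table[0, 1], at offset 44 within that block.
```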

View File

@ -1,29 +1,5 @@
/****************************************************************************** /******************************************************************************
* Copyright (c) 2022, Tri Dao. * Copyright (c) 2024, Tri Dao.
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/ ******************************************************************************/
#include <c10/core/ScalarType.h> #include <c10/core/ScalarType.h>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
@ -50,6 +26,7 @@
#include <ATen/ops/slice.h> #include <ATen/ops/slice.h>
#include <ATen/ops/narrow.h> #include <ATen/ops/narrow.h>
#include <ATen/ops/pad.h> #include <ATen/ops/pad.h>
#include <ATen/ops/zeros.h>
#endif #endif
@ -93,11 +70,11 @@ void set_params_fprop(Flash_fwd_params &params,
float p_dropout, float p_dropout,
float softmax_scale, float softmax_scale,
int window_size_left, int window_size_left,
int window_size_right) { int window_size_right,
bool seqlenq_ngroups_swapped=false) {
// Reset the parameters should be equivalent // Reset the parameters
params = {}; params = {};
// memset(&params, 0, sizeof(params));
params.is_bf16 = q.dtype() == at::kBFloat16; params.is_bf16 = q.dtype() == at::kBFloat16;
@ -121,6 +98,10 @@ void set_params_fprop(Flash_fwd_params &params,
params.k_batch_stride = k.stride(0); params.k_batch_stride = k.stride(0);
params.v_batch_stride = v.stride(0); params.v_batch_stride = v.stride(0);
params.o_batch_stride = out.stride(0); params.o_batch_stride = out.stride(0);
if (seqlenq_ngroups_swapped) {
params.q_batch_stride *= seqlen_q;
params.o_batch_stride *= seqlen_q;
}
} }
params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d); params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
@ -159,6 +140,9 @@ void set_params_fprop(Flash_fwd_params &params,
params.rp_dropout = 1.f / params.p_dropout; params.rp_dropout = 1.f / params.p_dropout;
params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax; params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
TORCH_CHECK(p_dropout < 1.f); TORCH_CHECK(p_dropout < 1.f);
#ifdef FLASHATTENTION_DISABLE_DROPOUT
TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
#endif
// Causal is the special case where window_size_right == 0 and window_size_left < 0. // Causal is the special case where window_size_right == 0 and window_size_left < 0.
// Local is the more general case where window_size_right >= 0 or window_size_left >= 0. // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
@ -169,7 +153,16 @@ void set_params_fprop(Flash_fwd_params &params,
params.window_size_left = window_size_left; params.window_size_left = window_size_left;
params.window_size_right = window_size_right; params.window_size_right = window_size_right;
#ifdef FLASHATTENTION_DISABLE_LOCAL
TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0),
"This flash attention build does not support local attention.");
#endif
params.is_seqlens_k_cumulative = true; params.is_seqlens_k_cumulative = true;
#ifdef FLASHATTENTION_DISABLE_UNEVEN_K
TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
#endif
} }
void set_params_dgrad(Flash_bwd_params &params, void set_params_dgrad(Flash_bwd_params &params,
@ -202,7 +195,8 @@ void set_params_dgrad(Flash_bwd_params &params,
float p_dropout, float p_dropout,
float softmax_scale, float softmax_scale,
int window_size_left, int window_size_left,
int window_size_right) { int window_size_right,
bool deterministic) {
set_params_fprop(params, set_params_fprop(params,
b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded, b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
@ -244,11 +238,13 @@ void set_params_dgrad(Flash_bwd_params &params,
// Softmax sum // Softmax sum
params.dsoftmax_sum = dsoftmax_sum_d; params.dsoftmax_sum = dsoftmax_sum_d;
params.deterministic = deterministic;
} }
void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) { void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
FP16_SWITCH(!params.is_bf16, [&] { FP16_SWITCH(!params.is_bf16, [&] {
FWD_HEADDIM_SWITCH(params.d, [&] { HEADDIM_SWITCH(params.d, [&] {
if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0 if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
run_mha_fwd_<elem_type, kHeadDim>(params, stream); run_mha_fwd_<elem_type, kHeadDim>(params, stream);
} else { } else {
@ -300,16 +296,62 @@ inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n
return 1; return 1;
} }
void set_params_splitkv(Flash_fwd_params &params, const int batch_size,
const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q,
const int head_size_rounded, const float p_dropout,
const int num_splits, cudaDeviceProp *dprops, struct c10::TensorOptions opts) {
// This needs to match with run_mha_fwd_splitkv_dispatch
const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;
// Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
// In any case we don't expect seqlen_q to be larger than 64 for inference.
const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64;
params.num_splits = num_splits;
if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout
if (num_splits < 1) {
params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
}
if (params.num_splits > 1) {
at::Tensor softmax_lse_accum = at::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
at::Tensor out_accum = at::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
params.oaccum_ptr = out_accum.data_ptr();
}
TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
}
}
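To make the shapes in set_params_splitkv concrete, here is a small worked computation in Python mirroring the block-count arithmetic above. The internals of num_splits_heuristic are not shown in this diff, so only the block counts are reproduced; the chosen split count is capped at 128 and the accumulation buffers are float32.

```python
import math

def splitkv_blocks(head_size: int, max_seqlen_k: int, max_seqlen_q: int):
    # Must match run_mha_fwd_splitkv_dispatch's kBlockN choice.
    block_n = 256 if head_size <= 64 else (128 if head_size <= 128 else 64)
    num_n_blocks = math.ceil(max_seqlen_k / block_n)
    # kBlockM = 64 for the split-KV kernels; seqlen_q is expected to be small for decoding.
    num_m_blocks = math.ceil(max_seqlen_q / 64)
    return block_n, num_n_blocks, num_m_blocks

# Decoding example: head_size=128, 4096 cached tokens, one query token
# -> block_n=128, 32 key/value blocks, 1 query block.
print(splitkv_blocks(128, 4096, 1))
```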
void set_params_alibi(Flash_fwd_params &params, c10::optional<at::Tensor> &alibi_slopes_, int batch_size, int num_heads){
#ifdef FLASHATTENTION_DISABLE_ALIBI
TORCH_CHECK(!alibi_slopes_.has_value(), "This flash attention build does not support alibi.");
params.alibi_slopes_ptr = nullptr;
#else
if (alibi_slopes_.has_value()) {
auto alibi_slopes = alibi_slopes_.value();
TORCH_CHECK(alibi_slopes.dtype() == at::kFloat, "ALiBi slopes must have dtype fp32");
CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
TORCH_CHECK(alibi_slopes.sizes() == at::IntArrayRef({num_heads}) || alibi_slopes.sizes() == at::IntArrayRef({batch_size, num_heads}));
params.alibi_slopes_ptr = alibi_slopes.data_ptr();
params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
} else {
params.alibi_slopes_ptr = nullptr;
}
#endif
}
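For illustration, here is a tensor that would satisfy set_params_alibi's checks (float32, on the CUDA device, contiguous last dimension, shaped (num_heads,) or (batch_size, num_heads)). The geometric slope schedule below is the conventional ALiBi choice and is an assumption, not something this diff prescribes; note also that PyTorch's build defines FLASHATTENTION_DISABLE_ALIBI by default, so this path is only reachable in builds that drop that define.

```python
import torch

def alibi_slopes(num_heads: int, device: str = "cuda") -> torch.Tensor:
    # Conventional ALiBi schedule for a power-of-two head count: slope_i = 2^(-8*(i+1)/num_heads).
    ratio = 2.0 ** (-8.0 / num_heads)
    return torch.tensor([ratio ** (i + 1) for i in range(num_heads)],
                        dtype=torch.float32, device=device)

slopes = alibi_slopes(8)   # shape (num_heads,), float32, contiguous
# A per-batch variant with shape (batch_size, num_heads) is also accepted.
```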
// return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p}; // return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p};
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout, const float p_dropout,
const float softmax_scale, const float softmax_scale,
bool is_causal, bool is_causal,
const int window_size_left, int window_size_left,
int window_size_right, int window_size_right,
const bool return_softmax, const bool return_softmax,
c10::optional<at::Generator> gen_) { c10::optional<at::Generator> gen_) {
@ -350,12 +392,16 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
if (seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case if (window_size_left >= seqlen_k) { window_size_left = -1; }
if (window_size_right >= seqlen_k) { window_size_right = -1; }
// causal=true is the same as causal=false in this case
if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
if (is_causal) { window_size_right = 0; } if (is_causal) { window_size_right = 0; }
// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
// H/t Daniel Haziza // H/t Daniel Haziza
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0; const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
at::Tensor temp_q = q; at::Tensor temp_q = q;
if (seqlenq_ngroups_swapped) { if (seqlenq_ngroups_swapped) {
const int ngroups = num_heads / num_heads_k; const int ngroups = num_heads / num_heads_k;
@ -369,9 +415,9 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og); CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);
at::Tensor q_padded, k_padded, v_padded; at::Tensor q_padded, k_padded, v_padded;
q_padded = temp_q; q_padded = temp_q;
k_padded = k; k_padded = k;
v_padded = v; v_padded = v;
at::Tensor out; at::Tensor out;
if (out_.has_value()) { if (out_.has_value()) {
@ -423,30 +469,17 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
window_size_left, window_size_left,
window_size_right); window_size_right);
// This needs to match with run_mha_fwd_splitkv_dispatch
const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); set_params_splitkv(params, batch_size, num_heads,
const int num_n_blocks = (seqlen_k + block_n - 1) / block_n; head_size, seqlen_k, seqlen_q,
// Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel. head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts);
// In any case we don't expect seqlen_q to be larger than 64 for inference.
const int num_m_blocks = (seqlen_q + 64 - 1) / 64;
params.num_splits = 1;
if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout
params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
if (params.num_splits > 1) {
at::Tensor softmax_lse_accum = at::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
at::Tensor out_accum = at::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
params.oaccum_ptr = out_accum.data_ptr();
}
TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
}
// We want to checkpoint and save the RNG state for backward if dropout // We want to checkpoint and save the RNG state for backward if dropout
// We get the default generator and return the seed and offset which will // We get the default generator and return the seed and offset which will
// be used in the backward function // be used in the backward function
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
at::Tensor seed_t, offset_t; at::Tensor seed_t, offset_t;
if (p_dropout > 0.0) { if (p_dropout > 0.0) {
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
// number of times random will be generated per thread, to offset philox counter in thc random // number of times random will be generated per thread, to offset philox counter in thc random
// state // state
// We use a custom RNG that increases the offset by batch_size * nheads * 32. // We use a custom RNG that increases the offset by batch_size * nheads * 32.
@ -476,6 +509,8 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
} }
set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
if (seqlen_k > 0) { if (seqlen_k > 0) {
auto stream = at::cuda::getCurrentCUDAStream().stream(); auto stream = at::cuda::getCurrentCUDAStream().stream();
run_mha_fwd(params, stream); run_mha_fwd(params, stream);
@ -501,18 +536,18 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1 const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used. c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
const int max_seqlen_q, c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
int max_seqlen_q,
const int max_seqlen_k, const int max_seqlen_k,
const float p_dropout, const float p_dropout,
const float softmax_scale, const float softmax_scale,
const bool zero_tensors, const bool zero_tensors,
const bool is_causal, bool is_causal,
const int window_size_left, int window_size_left,
int window_size_right, int window_size_right,
const bool return_softmax, const bool return_softmax,
c10::optional<at::Generator> gen_) { c10::optional<at::Generator> gen_) {
if (is_causal) { window_size_right = 0; }
auto dprops = at::cuda::getCurrentDeviceProperties(); auto dprops = at::cuda::getCurrentDeviceProperties();
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5; // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
@ -544,17 +579,39 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
const auto sizes = q.sizes(); const auto sizes = q.sizes();
const int total_q = sizes[0];
const int batch_size = cu_seqlens_q.numel() - 1; const int batch_size = cu_seqlens_q.numel() - 1;
const int num_heads = sizes[1]; int num_heads = sizes[1];
const int head_size_og = sizes[2]; const int head_size_og = sizes[2];
const int total_k = k.size(0); const int total_k = k.size(0);
const int num_heads_k = k.size(1); const int num_heads_k = k.size(1);
if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case
if (is_causal) { window_size_right = 0; }
void *cu_seqlens_q_d = cu_seqlens_q.data_ptr();
// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
// H/t Daniel Haziza
const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
at::Tensor temp_q = q;
if (seqlenq_ngroups_swapped) {
const int ngroups = num_heads / num_heads_k;
temp_q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
max_seqlen_q = ngroups;
num_heads = num_heads_k;
cu_seqlens_q_d = nullptr;
}
const int total_q = q.sizes()[0];
TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
TORCH_CHECK(head_size_og % 8 == 0, "head_size must be a multiple of 8, this is ensured by padding!") TORCH_CHECK(head_size_og % 8 == 0, "head_size must be a multiple of 8, this is ensured by padding!")
if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
CHECK_SHAPE(q, total_q, num_heads, head_size_og); CHECK_SHAPE(q, total_q, num_heads, head_size_og);
CHECK_SHAPE(k, total_k, num_heads_k, head_size_og); CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
CHECK_SHAPE(v, total_k, num_heads_k, head_size_og); CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
@ -569,7 +626,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
} }
at::Tensor q_padded, k_padded, v_padded; at::Tensor q_padded, k_padded, v_padded;
q_padded = q; q_padded = temp_q;
k_padded = k; k_padded = k;
v_padded = v; v_padded = v;
@ -619,7 +676,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
num_heads, num_heads_k, num_heads, num_heads_k,
head_size, head_size_rounded, head_size, head_size_rounded,
q_padded, k_padded, v_padded, out, q_padded, k_padded, v_padded, out,
cu_seqlens_q.data_ptr(), cu_seqlens_q_d,
cu_seqlens_k.data_ptr(), cu_seqlens_k.data_ptr(),
seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr, seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
return_softmax ? p.data_ptr() : nullptr, return_softmax ? p.data_ptr() : nullptr,
@ -627,9 +684,16 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
p_dropout, p_dropout,
softmax_scale, softmax_scale,
window_size_left, window_size_left,
window_size_right); window_size_right,
seqlenq_ngroups_swapped);
if (seqlenq_ngroups_swapped) {
// Only apply split-k for decoding
set_params_splitkv(params, batch_size, num_heads,
head_size, max_seqlen_k, max_seqlen_q,
head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts);
}
// We want to checkpoint and save the RNG state for backward if dropout // We want to checkpoint and save the RNG state for backward if dropout
// We get the default generator and return the seed and offset which will // We get the default generator and return the seed and offset which will
// be used in the backward function // be used in the backward function
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
@ -664,31 +728,33 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
} }
auto stream = at::cuda::getCurrentCUDAStream().stream(); set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
run_mha_fwd(params, stream);
if (max_seqlen_k > 0) {
auto stream = at::cuda::getCurrentCUDAStream().stream();
run_mha_fwd(params, stream);
} else {
// If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
out.zero_();
softmax_lse.fill_(std::numeric_limits<float>::infinity());
}
if (seqlenq_ngroups_swapped) {
std::array<int64_t, 4> size_before = {batch_size, max_seqlen_q, num_heads_k, head_size_og};
std::array<int64_t, 3> size_after = {batch_size, num_heads_k * max_seqlen_q, head_size_og};
out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
q_padded = q_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * max_seqlen_q, 1});
}
return {out, q_padded, k_padded, v_padded, softmax_lse, seed_t, offset_t, p}; return {out, q_padded, k_padded, v_padded, softmax_lse, seed_t, offset_t, p};
} }
void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool configure) { void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
FP16_SWITCH(!params.is_bf16, [&] { FP16_SWITCH(!params.is_bf16, [&] {
if (params.d <= 32) { HEADDIM_SWITCH(params.d, [&] {
run_mha_bwd_<elem_type, 32>(params, stream, configure); run_mha_bwd_<elem_type, kHeadDim>(params, stream);
} else if (params.d <= 64) { });
run_mha_bwd_<elem_type, 64>(params, stream, configure);
} else if (params.d <= 96) {
run_mha_bwd_<elem_type, 96>(params, stream, configure);
} else if (params.d <= 128) {
run_mha_bwd_<elem_type, 128>(params, stream, configure);
} else if (params.d <= 160) {
run_mha_bwd_<elem_type, 160>(params, stream, configure);
} else if (params.d <= 192) {
run_mha_bwd_<elem_type, 192>(params, stream, configure);
} else if (params.d <= 224) {
run_mha_bwd_<elem_type, 224>(params, stream, configure);
} else if (params.d <= 256) {
run_mha_bwd_<elem_type, 256>(params, stream, configure);
}
}); });
} }
@ -702,14 +768,19 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout, // probability to drop const float p_dropout, // probability to drop
const float softmax_scale, const float softmax_scale,
const bool is_causal, const bool is_causal,
const int window_size_left, int window_size_left,
int window_size_right, int window_size_right,
const bool deterministic,
const at::Tensor philox_seed, const at::Tensor philox_seed,
const at::Tensor philox_offset) { const at::Tensor philox_offset) {
#ifdef FLASHATTENTION_DISABLE_BACKWARD
TORCH_CHECK(false, "This flash attention build does not support backward.");
#endif
if (is_causal) { window_size_right = 0; } if (is_causal) { window_size_right = 0; }
auto dprops = at::cuda::getCurrentDeviceProperties(); auto dprops = at::cuda::getCurrentDeviceProperties();
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5; // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
@ -756,8 +827,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!"); TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!");
TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256"); TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
if (head_size > 192) { if (head_size > 192 && (head_size <= 224 || is_dropout)) {
TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 requires A100/A800 or H100/H800"); TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800");
} }
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
@ -768,6 +839,9 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8"); TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
if (window_size_left >= seqlen_k) { window_size_left = -1; }
if (window_size_right >= seqlen_k) { window_size_right = -1; }
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
@ -803,8 +877,6 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
dv = at::empty_like(v); dv = at::empty_like(v);
} }
// const at::Tensor& dout_padded = dout;
// bool loop = seqlen_k > blocksize_c; // bool loop = seqlen_k > blocksize_c;
// TODO: change later, for now set to true for simplicity // TODO: change later, for now set to true for simplicity
bool loop = true; bool loop = true;
@ -818,9 +890,14 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
at::Tensor dq_accum; at::Tensor dq_accum;
at::Tensor dk_accum, dv_accum; at::Tensor dk_accum, dv_accum;
if (loop) { if (loop) {
dq_accum = at::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); if (!deterministic) {
// dk_accum = at::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat)); dq_accum = at::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
// dv_accum = at::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat)); } else {
const int nsplits = (dprops->multiProcessorCount + batch_size * num_heads - 1) / (batch_size * num_heads);
dq_accum = at::zeros({nsplits, batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
}
// dk_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
// dv_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
} }
at::Tensor dk_expanded, dv_expanded; at::Tensor dk_expanded, dv_expanded;
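The deterministic backward replaces atomic dq accumulation with per-split buffers, sized so the available SMs can be spread over batch * heads. A quick arithmetic check of the allocation above (the hardware and shape values are an illustrative example, not taken from this diff):

```python
import math

def dq_accum_splits(multi_processor_count: int, batch_size: int, num_heads: int) -> int:
    # nsplits = ceil(SM count / (batch_size * num_heads)), as computed above.
    return math.ceil(multi_processor_count / (batch_size * num_heads))

# e.g. an A100 with 108 SMs, batch_size=4, num_heads=16 -> ceil(108 / 64) = 2 splits,
# so dq_accum is allocated as [2, batch, seqlen_q_rounded, heads, head_size_rounded]
# and zero-initialized instead of relying on atomic adds.
print(dq_accum_splits(108, 4, 16))
```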
@ -854,10 +931,11 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
p_dropout, p_dropout,
softmax_scale, softmax_scale,
window_size_left, window_size_left,
window_size_right); window_size_right,
deterministic);
params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
auto launch = &run_mha_bwd; auto launch = &run_mha_bwd;
// launch(params, stream, /*configure=*/true);
at::PhiloxCudaState philox_args; at::PhiloxCudaState philox_args;
if (is_dropout) { if (is_dropout) {
@ -872,12 +950,14 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
} }
params.philox_args = philox_args; params.philox_args = philox_args;
set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
if (seqlen_q > 0) { if (seqlen_q > 0) {
launch(params, stream, /*configure=*/false); launch(params, stream);
} else { } else {
// If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
dk.zero_(); dk_expanded.zero_();
dv.zero_(); dv_expanded.zero_();
softmax_d.zero_(); softmax_d.zero_();
} }
@ -901,17 +981,24 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
c10::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i c10::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1 const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
const int max_seqlen_q, const int max_seqlen_q,
const int max_seqlen_k, // max sequence length to choose the kernel const int max_seqlen_k, // max sequence length to choose the kernel
const float p_dropout, // probability to drop const float p_dropout, // probability to drop
const float softmax_scale, const float softmax_scale,
const bool zero_tensors, const bool zero_tensors,
const bool is_causal, const bool is_causal,
const int window_size_left, int window_size_left,
int window_size_right, int window_size_right,
const bool deterministic,
const at::Tensor philox_seed, const at::Tensor philox_seed,
const at::Tensor philox_offset) const at::Tensor philox_offset)
{ {
#ifdef FLASHATTENTION_DISABLE_BACKWARD
TORCH_CHECK(false, "This flash attention build does not support backward.");
#endif
if (is_causal) { window_size_right = 0; } if (is_causal) { window_size_right = 0; }
auto dprops = at::cuda::getCurrentDeviceProperties(); auto dprops = at::cuda::getCurrentDeviceProperties();
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5; // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
@ -925,7 +1012,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
auto stream = at::cuda::getCurrentCUDAStream().stream(); auto stream = at::cuda::getCurrentCUDAStream().stream();
auto q_dtype = q.dtype(); auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == at::kHalf|| q_dtype == at::kBFloat16, TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
"FlashAttention only support fp16 and bf16 data type"); "FlashAttention only support fp16 and bf16 data type");
if (q_dtype == at::kBFloat16) { if (q_dtype == at::kBFloat16) {
TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer"); TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
@ -962,8 +1049,8 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!"); TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!");
TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256"); TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
if (head_size > 192) { if (head_size > 192 && (head_size <= 224 || is_dropout)) {
TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 requires A100/A800 or H100/H800"); TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800");
} }
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
@ -974,6 +1061,9 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8"); TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
CHECK_SHAPE(q, total_q, num_heads, head_size); CHECK_SHAPE(q, total_q, num_heads, head_size);
CHECK_SHAPE(k, total_k, num_heads_k, head_size); CHECK_SHAPE(k, total_k, num_heads_k, head_size);
CHECK_SHAPE(v, total_k, num_heads_k, head_size); CHECK_SHAPE(v, total_k, num_heads_k, head_size);
@ -1008,11 +1098,9 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
CHECK_SHAPE(dv, total_k, num_heads_k, head_size); CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
} else { } else {
dv = at::empty_like(k); dv = at::empty_like(v);
} }
// const at::Tensor& dout_padded = dout;
// bool loop = max_seqlen_k > blocksize_c; // bool loop = max_seqlen_k > blocksize_c;
// TODO: change later, for now set to true for simplicity // TODO: change later, for now set to true for simplicity
bool loop = true; bool loop = true;
@ -1033,7 +1121,12 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
// cu_seqlens[i + 1] * 128 * i - 1. This ensures that the i-th sequence and (i + 1)-th sequence will // cu_seqlens[i + 1] * 128 * i - 1. This ensures that the i-th sequence and (i + 1)-th sequence will
// be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally // be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally
// allowed to do. So we won't have to do any bound checking, and performance should stay the same. // allowed to do. So we won't have to do any bound checking, and performance should stay the same.
dq_accum = at::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); if (!deterministic) {
dq_accum = at::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
} else {
const int nsplits = (dprops->multiProcessorCount + batch_size * num_heads - 1) / (batch_size * num_heads);
dq_accum = at::zeros({nsplits, total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
}
} }
at::Tensor dk_expanded, dv_expanded; at::Tensor dk_expanded, dv_expanded;
@ -1072,10 +1165,11 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
p_dropout, p_dropout,
softmax_scale, softmax_scale,
window_size_left, window_size_left,
window_size_right); window_size_right,
deterministic);
params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
auto launch = &run_mha_bwd; auto launch = &run_mha_bwd;
// launch(params, stream, /*configure=*/true);
at::PhiloxCudaState philox_args; at::PhiloxCudaState philox_args;
if (is_dropout) { if (is_dropout) {
@ -1090,7 +1184,16 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
} }
params.philox_args = philox_args; params.philox_args = philox_args;
launch(params, stream, /*configure=*/false); set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
if (max_seqlen_q > 0) {
launch(params, stream);
} else {
// If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
dk_expanded.zero_();
dv_expanded.zero_();
softmax_d.zero_();
}
// For MQA/GQA we need to sum dK and dV across the groups // For MQA/GQA we need to sum dK and dV across the groups
if (num_heads_k != num_heads) { if (num_heads_k != num_heads) {
@ -1103,18 +1206,20 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
std::tuple<at::Tensor, at::Tensor> std::tuple<at::Tensor, at::Tensor>
mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
c10::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size c10::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
c10::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size c10::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
c10::optional<const at::Tensor> &seqlens_k_, // batch_size c10::optional<const at::Tensor> &seqlens_k_, // batch_size
c10::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2) c10::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
c10::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2) c10::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
c10::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache c10::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
const float softmax_scale, const float softmax_scale,
bool is_causal, bool is_causal,
const int window_size_left, int window_size_left,
int window_size_right, int window_size_right,
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
int num_splits int num_splits
@ -1143,25 +1248,41 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
at::Tensor block_table;
const bool paged_KV = block_table_.has_value();
if (paged_KV) {
TORCH_CHECK(!cache_batch_idx_.has_value(), "Paged KVcache does not support cache_batch_idx");
block_table = block_table_.value();
CHECK_DEVICE(block_table);
TORCH_CHECK(block_table.dtype() == at::kInt, "block_table must have dtype torch.int32");
TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
}
const auto sizes = q.sizes(); const auto sizes = q.sizes();
const int batch_size = sizes[0]; const int batch_size = sizes[0];
int seqlen_q = sizes[1]; int seqlen_q = sizes[1];
int num_heads = sizes[2]; int num_heads = sizes[2];
const int head_size_og = sizes[3]; const int head_size_og = sizes[3];
const int seqlen_k = kcache.size(1);
const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
const int num_blocks = !paged_KV ? 0 : kcache.size(0);
const int page_block_size = !paged_KV ? 1 : kcache.size(1);
TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");
const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size;
const int num_heads_k = kcache.size(2); const int num_heads_k = kcache.size(2);
const int batch_size_c = kcache.size(0); const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size;
TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(batch_size > 0, "batch size must be postive");
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
if (seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case // causal=true is the same as causal=false in this case
if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
if (is_causal) { window_size_right = 0; } if (is_causal) { window_size_right = 0; }
// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
// H/t Daniel Haziza // H/t Daniel Haziza
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0; const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
if (seqlenq_ngroups_swapped) { if (seqlenq_ngroups_swapped) {
const int ngroups = num_heads / num_heads_k; const int ngroups = num_heads / num_heads_k;
q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2); q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
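The `seqlenq_ngroups_swapped` trick above (decode-time GQA, now also gated on ALiBi being absent) treats the query groups as a fake sequence dimension so the kernel parallelises over `ngroups` rows instead of a single query row. A shape-only illustration with assumed sizes:

```cpp
#include <torch/torch.h>
#include <iostream>

// Shape-only sketch of the seqlenq_ngroups_swapped path; all sizes are assumed.
int main() {
  const int64_t batch_size = 2, num_heads = 32, num_heads_k = 4, head_size = 128;
  const int64_t ngroups = num_heads / num_heads_k;  // 8
  auto q = torch::randn({batch_size, /*seqlen_q=*/1, num_heads, head_size});
  auto q_swapped =
      q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2);
  std::cout << q_swapped.sizes() << std::endl;  // [2, 8, 4, 128] = (b, ngroups, nheads_kv, d)
  return 0;
}
```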
@ -1169,9 +1290,18 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
num_heads = num_heads_k; num_heads = num_heads_k;
} }
if (window_size_left >= seqlen_k) { window_size_left = -1; }
if (window_size_right >= seqlen_k) { window_size_right = -1; }
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og); CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og); if (!paged_KV) {
CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og); CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
} else {
CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
}
at::Tensor q_padded, kcache_padded, vcache_padded; at::Tensor q_padded, kcache_padded, vcache_padded;
if (head_size_og % 8 != 0) { if (head_size_og % 8 != 0) {
@ -1310,27 +1440,24 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
TORCH_CHECK(cache_batch_idx.scalar_type() == at::kInt, "cache_batch_idx must have dtype int32"); TORCH_CHECK(cache_batch_idx.scalar_type() == at::kInt, "cache_batch_idx must have dtype int32");
params.cache_batch_idx = reinterpret_cast<int *>(cache_batch_idx.data_ptr()); params.cache_batch_idx = reinterpret_cast<int *>(cache_batch_idx.data_ptr());
} }
// This needs to match with run_mha_fwd_splitkv_dispatch
const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); set_params_splitkv(params, batch_size, num_heads,
const int num_n_blocks = (seqlen_k + block_n - 1) / block_n; head_size, seqlen_k, seqlen_q,
// Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel. head_size_rounded, /*dropout*/0.f, num_splits, dprops, opts);
// In any case we don't expect seqlen_q to be larger than 64 for inference.
const int num_m_blocks = (seqlen_q + 64 - 1) / 64; if (paged_KV) {
params.num_splits = num_splits; params.block_table = block_table.data_ptr<int>();
if (num_splits < 1) { params.block_table_batch_stride = block_table.stride(0);
params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
}
TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
if (params.num_splits > 1) {
at::Tensor softmax_lse_accum = at::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
at::Tensor out_accum = at::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
params.oaccum_ptr = out_accum.data_ptr();
} }
params.page_block_size = page_block_size;
set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
auto stream = at::cuda::getCurrentCUDAStream().stream(); auto stream = at::cuda::getCurrentCUDAStream().stream();
// Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx // Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx,
run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() || cache_batch_idx_.has_value()); // or paged KV cache
run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() || cache_batch_idx_.has_value() || paged_KV);
if (head_size_og % 8 != 0) { if (head_size_og % 8 != 0) {
// out = out.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)}); // out = out.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)});
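The split-KV sizing that used to be inlined here (left column) now lives in `set_params_splitkv`, and the split kernel is forced whenever new keys are appended, `cache_batch_idx` is used, or the cache is paged. A worked sketch of the block counting that feeds the split heuristic, with assumed sizes:

```cpp
#include <cstdio>

// Mirrors the removed inline arithmetic: block_n depends on head_size and must
// match run_mha_fwd_splitkv_dispatch; kBlockM is 64 for the split-KV kernels.
// head_size and sequence lengths below are assumptions for illustration.
int main() {
  const int head_size = 128, seqlen_k = 4096, seqlen_q = 1;
  const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);  // 128
  const int num_n_blocks = (seqlen_k + block_n - 1) / block_n;                // 32
  const int num_m_blocks = (seqlen_q + 64 - 1) / 64;                          // 1
  std::printf("num_n_blocks=%d num_m_blocks=%d\n", num_n_blocks, num_m_blocks);
  // num_splits_heuristic(batch*heads*num_m_blocks, SM count, num_n_blocks, 128)
  // then picks params.num_splits (capped at 128); with more than one split,
  // softmax_lse_accum and out_accum gain a leading num_splits dimension.
  return 0;
}
```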
@ -1352,6 +1479,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
} }
return {out, softmax_lse}; return {out, softmax_lse};
} }
} // namespace pytorch_fmha } // namespace pytorch_fmha
#endif #endif
View File
@ -12,10 +12,11 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout, const float p_dropout,
const float softmax_scale, const float softmax_scale,
bool is_causal, bool is_causal,
const int window_size_left, int window_size_left,
int window_size_right, int window_size_right,
const bool return_softmax, const bool return_softmax,
c10::optional<at::Generator> gen_); c10::optional<at::Generator> gen_);
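`alibi_slopes_` is the only new argument to `mha_fwd`: a float tensor with one slope per head, optionally with a leading batch dimension. How the slopes are produced is up to the caller; the geometric sequence below is the usual ALiBi recipe and is shown only as an assumed example, not part of this API:

```cpp
#include <torch/torch.h>
#include <cmath>
#include <iostream>

// Assumed example of building alibi_slopes in the "num_heads" layout. The
// 2^(-8*(i+1)/num_heads) geometric sequence is the common ALiBi choice for a
// power-of-two head count; the kernel itself only consumes the float tensor.
int main() {
  const int64_t num_heads = 8;
  auto slopes = torch::empty({num_heads}, torch::kFloat32);
  for (int64_t i = 0; i < num_heads; ++i) {
    slopes[i] = std::pow(2.0, -8.0 * static_cast<double>(i + 1) / num_heads);
  }
  std::cout << slopes << std::endl;  // 0.5, 0.25, ..., 2^-8
  return 0;
}
```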
@ -28,13 +29,14 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1 const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used. c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
const int max_seqlen_q, c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
int max_seqlen_q,
const int max_seqlen_k, const int max_seqlen_k,
const float p_dropout, const float p_dropout,
const float softmax_scale, const float softmax_scale,
const bool zero_tensors, const bool zero_tensors,
const bool is_causal, bool is_causal,
const int window_size_left, int window_size_left,
int window_size_right, int window_size_right,
const bool return_softmax, const bool return_softmax,
c10::optional<at::Generator> gen_); c10::optional<at::Generator> gen_);
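For the varlen entry points, `cu_seqlens_q`/`cu_seqlens_k` are exclusive prefix sums over the per-sequence lengths (length `b+1`, int32), and `total_q`/`total_k` are their final entries. A tiny illustration with assumed lengths:

```cpp
#include <torch/torch.h>
#include <iostream>

// Assumed example: sequences of length 3, 5 and 2 are packed into 10 rows, and
// cu_seqlens = [0, 3, 8, 10] marks where each sequence starts and ends.
int main() {
  auto seqlens = torch::tensor({3, 5, 2}, torch::kInt32);
  auto cu_seqlens = torch::zeros({seqlens.size(0) + 1}, torch::kInt32);
  cu_seqlens.narrow(0, 1, seqlens.size(0)).copy_(torch::cumsum(seqlens, 0));
  std::cout << cu_seqlens << std::endl;  // 0  3  8  10
  return 0;
}
```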
@ -50,11 +52,13 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout, // probability to drop const float p_dropout, // probability to drop
const float softmax_scale, const float softmax_scale,
const bool is_causal, const bool is_causal,
const int window_size_left, int window_size_left,
int window_size_right, int window_size_right,
const bool deterministic,
const at::Tensor philox_seed, const at::Tensor philox_seed,
const at::Tensor philox_offset); const at::Tensor philox_offset);
@ -70,14 +74,16 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
c10::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i c10::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1 const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
const int max_seqlen_q, const int max_seqlen_q,
const int max_seqlen_k, // max sequence length to choose the kernel const int max_seqlen_k, // max sequence length to choose the kernel
const float p_dropout, // probability to drop const float p_dropout, // probability to drop
const float softmax_scale, const float softmax_scale,
const bool zero_tensors, const bool zero_tensors,
const bool is_causal, const bool is_causal,
const int window_size_left, int window_size_left,
int window_size_right, int window_size_right,
const bool deterministic,
const at::Tensor philox_seed, const at::Tensor philox_seed,
const at::Tensor philox_offset); const at::Tensor philox_offset);
View File
@ -1,4 +1,6 @@
// Copyright (c) 2022, Tri Dao. /******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#pragma once #pragma once
@ -6,58 +8,81 @@
#include <ATen/native/transformers/cuda/flash_attn/static_switch.h> #include <ATen/native/transformers/cuda/flash_attn/static_switch.h>
#include <ATen/native/transformers/cuda/flash_attn/flash.h> #include <ATen/native/transformers/cuda/flash_attn/flash.h>
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_preprocess_kernel.h>
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_kernel.h> #include <ATen/native/transformers/cuda/flash_attn/flash_bwd_kernel.h>
namespace pytorch_flash { namespace pytorch_flash {
// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#define ARCH_SUPPORTS_FLASH
#define KERNEL_PARAM_MODIFIER __grid_constant__
#else
#define KERNEL_PARAM_MODIFIER
#endif
// Define a macro for unsupported architecture handling to centralize the error message
#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!");
// Use a macro to clean up kernel definitions
#define DEFINE_FLASH_BACKWARD_KERNEL(kernelName, ...) \
template<typename Kernel_traits, __VA_ARGS__> \
__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params params)
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K) {
#if defined(ARCH_SUPPORTS_FLASH)
pytorch_flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
}
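For readers following the new launch template: `DEFINE_FLASH_BACKWARD_KERNEL` only stamps out the template head and the (possibly `__grid_constant__`) parameter declaration; the braces that follow each use supply the body. Expanding the first use by hand gives roughly the following (not a standalone file, it assumes the surrounding headers):

```cpp
// Hand expansion of the first DEFINE_FLASH_BACKWARD_KERNEL use above, assuming
// ARCH_SUPPORTS_FLASH is defined (so KERNEL_PARAM_MODIFIER == __grid_constant__).
// Flash_bwd_params and compute_dq_dk_dv come from the surrounding headers.
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi,
         bool Is_even_M, bool Is_even_K>
__global__ void flash_bwd_dq_dk_dv_loop_kernel(
    __grid_constant__ const Flash_bwd_params params) {
  pytorch_flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Has_alibi,
                                  Is_even_M, Is_even_K>(params);
}
```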
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K) {
#if defined(ARCH_SUPPORTS_FLASH)
static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
pytorch_flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
}
template<bool Clear_dQaccum=true, typename Kernel_traits> template<bool Clear_dQaccum=true, typename Kernel_traits>
__global__ void flash_bwd_dot_do_o_kernel(Flash_bwd_params params) { __global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) {
pytorch_flash::compute_dot_do_o<Clear_dQaccum, Kernel_traits>(params); pytorch_flash::compute_dot_do_o<Clear_dQaccum, Kernel_traits>(params);
} }
template<typename Kernel_traits> template<typename Kernel_traits>
__global__ void flash_bwd_clear_dkvaccum_kernel(Flash_bwd_params params) { __global__ void flash_bwd_clear_dkvaccum_kernel(const Flash_bwd_params params) {
pytorch_flash::clear_dKVaccum<Kernel_traits>(params); pytorch_flash::clear_dKVaccum<Kernel_traits>(params);
} }
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_M, bool Is_even_K> template<typename Kernel_traits>
__global__ void flash_bwd_dq_dk_dv_loop_kernel(Flash_bwd_params params) { __global__ void flash_bwd_convert_dq_kernel(const Flash_bwd_params params, const int nsplits) {
pytorch_flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Is_even_M, Is_even_K>(params); pytorch_flash::convert_dQ<Kernel_traits>(params, nsplits);
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K>
__global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(Flash_bwd_params params) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
pytorch_flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K>(params);
#else
printf("FATAL: FlashAttention requires to be build with sm80-sm90, but was built for < 8.0!");
#endif
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K>
__global__ void flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel(Flash_bwd_params params) {
pytorch_flash::compute_dq_dk_dv_seqq_parallel<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K>(params);
} }
template<typename Kernel_traits> template<typename Kernel_traits>
__global__ void flash_bwd_convert_dq_kernel(Flash_bwd_params params) { __global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) {
pytorch_flash::convert_dQ<Kernel_traits>(params);
}
template<typename Kernel_traits>
__global__ void flash_bwd_convert_dkv_kernel(Flash_bwd_params params) {
pytorch_flash::convert_dKV<Kernel_traits>(params); pytorch_flash::convert_dKV<Kernel_traits>(params);
} }
template<typename Kernel_traits, bool Is_dropout> template<typename Kernel_traits, bool Is_dropout>
void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream, const bool configure) { void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream) {
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
dim3 grid_m(num_m_block, params.b, params.h); dim3 grid_m(num_m_block, params.b, params.h);
const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
dim3 grid_n(num_n_block, params.b, params.h); int gridDimx = num_n_block;
if (params.deterministic) {
auto dprops = at::cuda::getCurrentDeviceProperties();
gridDimx = (dprops->multiProcessorCount + params.b * params.h - 1) / (params.b * params.h);
}
dim3 grid_n(gridDimx, params.b, params.h);
flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params); if (!params.deterministic) {
flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
} else {
flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
}
C10_CUDA_KERNEL_LAUNCH_CHECK(); C10_CUDA_KERNEL_LAUNCH_CHECK();
// We want to specialize to is_even_MN and not just is_even_M, since in the case where N is not // We want to specialize to is_even_MN and not just is_even_M, since in the case where N is not
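In deterministic mode the column-block grid is capped at the same split count used when `dq_accum` was allocated, so each (batch, head) pair accumulates its key blocks in a fixed order, and `flash_bwd_dot_do_o_kernel` is launched with `Clear_dQaccum=false` because the host already zero-filled the buffer. A worked example of the grid arithmetic (SM count and tile sizes assumed):

```cpp
#include <cstdio>

// Assumed numbers: 108 SMs, kBlockN = 128, seqlen_k = 4096, batch = 4, heads = 16.
int main() {
  const int multi_processor_count = 108, batch = 4, heads = 16;
  const int seqlen_k = 4096, kBlockN = 128;
  const int num_n_block = (seqlen_k + kBlockN - 1) / kBlockN;  // 32 column blocks
  int gridDimx = num_n_block;
  const bool deterministic = true;
  if (deterministic) {
    // Same expression as the nsplits used for the dq_accum allocation: ceil(108 / 64) == 2.
    gridDimx = (multi_processor_count + batch * heads - 1) / (batch * heads);
  }
  std::printf("grid_n = (%d, %d, %d)\n", gridDimx, batch, heads);
  return 0;
}
```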
@ -66,21 +91,23 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
const bool is_even_K = params.d == Kernel_traits::kHeadDim; const bool is_even_K = params.d == Kernel_traits::kHeadDim;
constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1colblock; constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1colblock;
// printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv); // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] {
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] { LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// If Is_local, set Is_causal to false // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>; // If Is_local, set Is_causal to false
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst, true>; auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
if (smem_size_dq_dk_dv >= 48 * 1024) { // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
C10_CUDA_CHECK(cudaFuncSetAttribute( if (smem_size_dq_dk_dv >= 48 * 1024) {
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); C10_CUDA_CHECK(cudaFuncSetAttribute(
} kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params); }
C10_CUDA_KERNEL_LAUNCH_CHECK(); kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}); });
}); });
}); });
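The `EVENK_SWITCH`/`LOCAL_SWITCH`/`ALIBI_SWITCH`/`DROPOUT_SWITCH` wrappers follow the same `BOOL_SWITCH` idiom from `static_switch.h`: a runtime boolean is lifted into a `constexpr` template parameter by instantiating both branches (the named variants exist so individual features can be compiled out). A minimal stand-alone sketch of the idiom, not the actual header:

```cpp
#include <cstdio>

// Minimal sketch of the BOOL_SWITCH idiom; the real macros live in static_switch.h.
#define BOOL_SWITCH(COND, CONST_NAME, ...)      \
  [&] {                                         \
    if (COND) {                                 \
      constexpr static bool CONST_NAME = true;  \
      return __VA_ARGS__();                     \
    } else {                                    \
      constexpr static bool CONST_NAME = false; \
      return __VA_ARGS__();                     \
    }                                           \
  }()

template <bool Is_dropout>
void run() { std::printf("Is_dropout=%d\n", int(Is_dropout)); }

int main() {
  const float p_dropout = 0.9f;  // dropout active when < 1.f, as in the switch above
  BOOL_SWITCH(p_dropout < 1.f, Is_dropout, [&] { run<Is_dropout>(); });
  return 0;
}
```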
@ -91,58 +118,19 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
C10_CUDA_CHECK(cudaFuncSetAttribute( C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize)); kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
} }
kernel_dq<<<grid_m, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params); kernel_dq<<<grid_m, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params, !params.deterministic ? 1 : gridDimx);
C10_CUDA_KERNEL_LAUNCH_CHECK(); C10_CUDA_KERNEL_LAUNCH_CHECK();
} }
template<typename Kernel_traits, bool Is_dropout> template<typename Kernel_traits, bool Is_dropout>
void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream, const bool configure) { void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; #ifndef FLASHATTENTION_DISABLE_BACKWARD
dim3 grid_n(num_n_block, params.b, params.h_k); run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream);
flash_bwd_clear_dkvaccum_kernel<Kernel_traits><<<grid_n, Kernel_traits::kNThreads, 0, stream>>>(params); #endif
C10_CUDA_KERNEL_LAUNCH_CHECK();
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
dim3 grid_m(num_m_block, params.b, params.h);
// We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check
// for cu_seqlens_k as well.
const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0;
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1rowblock;
// printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst && IsEvenKConst, IsEvenKConst>;
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, false, false, IsEvenNConst, IsEvenKConst>;
if (smem_size_dq_dk_dv >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
}
kernel<<<grid_m, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
auto kernel_dkv = &flash_bwd_convert_dkv_kernel<Kernel_traits>;
if (Kernel_traits::kSmemKVSize >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel_dkv, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemKVSize));
}
kernel_dkv<<<grid_n, Kernel_traits::kNThreads, Kernel_traits::kSmemKVSize, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
template<typename Kernel_traits, bool Is_dropout>
void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
if (configure) return;
run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream, configure);
} }
template<typename T> template<typename T>
void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream, const bool configure) { void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 32; constexpr static int Headdim = 32;
int device; int device;
cudaGetDevice(&device); cudaGetDevice(&device);
@ -152,21 +140,21 @@ void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream, const boo
if (status_ != cudaSuccess) { if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_); C10_CUDA_CHECK(status_);
} }
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) { // 104 KB if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) { // 104 KB
if constexpr(!Is_dropout) { // We can afford more registers to keep V in registers if constexpr(!Is_dropout) { // We can afford more registers to keep V in registers
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream, configure); run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
} else { } else {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure); run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
} }
} else { // 96 KB } else { // 96 KB
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream, configure); run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
} }
}); });
} }
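The `// 104 KB` annotation on the hdim32 branch checks out if the expression is read as fp16/bf16 element counts times 2 bytes (the exact tile breakdown is the kernel author's; only the arithmetic is verified here):

```cpp
#include <cstdio>

// 2 * ((3*128 + 2*128) * 32 + 2*128*128) = 2 * 53248 = 106496 bytes = 104 KB.
int main() {
  constexpr int Headdim = 32;
  constexpr int smem_bytes = 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128);
  static_assert(smem_bytes == 106496, "104 KB");
  std::printf("%d bytes = %d KB\n", smem_bytes, smem_bytes / 1024);
  return 0;
}
```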
template<typename T> template<typename T>
void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream, const bool configure) { void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 64; constexpr static int Headdim = 64;
int device; int device;
cudaGetDevice(&device); cudaGetDevice(&device);
@ -177,42 +165,41 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream, const boo
C10_CUDA_CHECK(status_); C10_CUDA_CHECK(status_);
} }
// printf("max_smem_per_block = %d\n", max_smem_per_block); // printf("max_smem_per_block = %d\n", max_smem_per_block);
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// Changing AtomLayoutMdQ from 2 to 4 takes the same time // Changing AtomLayoutMdQ from 2 to 4 takes the same time
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>>(params, stream, configure); // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>>(params, stream, configure); // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 2, 4, 4, false, false, T>>(params, stream, configure); // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 2, 4, 4, false, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream, configure); // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream);
// This is slightly faster. We want to split M more so we need fewer registers to store LSE. // This is slightly faster. We want to split M more so we need fewer registers to store LSE.
if (max_smem_per_block >= 144 * 1024) { if (max_smem_per_block >= 144 * 1024) {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure); run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
// This has a lot of register spilling // This has a lot of register spilling
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream, configure); // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
} else { } else {
// if (params.h == params.h_k) { // if (params.h == params.h_k) {
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream, configure); // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream);
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream, configure); run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream, configure); // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>, Is_dropout>(params, stream, configure); // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>, Is_dropout>(params, stream);
// } else { // } else {
// run_flash_bwd_seqq_parallel<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream, configure);
// } // }
} }
}); });
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>>(params, stream, configure); // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, 2, 2, 2, true, false, T>>(params, stream, configure); // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, 2, 2, 2, true, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 4, 1, 4, 1, false, false, T>>(params, stream, configure); // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 4, 1, 4, 1, false, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 16, 128, 4, 1, 4, 1, false, false, T>>(params, stream, configure); // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 16, 128, 4, 1, 4, 1, false, false, T>>(params, stream);
// M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times // M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 2, 2, 2, false, T>>(params, stream, configure); // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 2, 2, 2, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream, configure); // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream, configure); // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 4, 4, 2, 4, false, false, T>>(params, stream, configure); // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 4, 4, 2, 4, false, false, T>>(params, stream);
} }
template<typename T> template<typename T>
void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream, const bool configure) { void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 96; constexpr static int Headdim = 96;
int device; int device;
cudaGetDevice(&device); cudaGetDevice(&device);
@ -223,26 +210,22 @@ void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream, const boo
C10_CUDA_CHECK(status_); C10_CUDA_CHECK(status_);
} }
// printf("max_smem_per_block = %d\n", max_smem_per_block); // printf("max_smem_per_block = %d\n", max_smem_per_block);
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// if (params.h == params.h_k) { if (max_smem_per_block >= 116 * 1024) {
if (max_smem_per_block >= 116 * 1024) { if constexpr(!Is_dropout) { // 92KB
if constexpr(!Is_dropout) { // 92KB run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream, configure); } else { // 116 KB
} else { // 116 KB // This is faster for dropout since we don't have many registers to spare
// This is faster for dropout since we don't have many registers to spare run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream);
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
}
} else {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
} }
// } else { } else {
// run_flash_bwd_seqq_parallel<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream, configure); run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
// } }
}); });
} }
template<typename T> template<typename T>
void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream, const bool configure) { void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 128; constexpr static int Headdim = 128;
int device; int device;
cudaGetDevice(&device); cudaGetDevice(&device);
@ -253,35 +236,30 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream, const bo
C10_CUDA_CHECK(status_); C10_CUDA_CHECK(status_);
} }
// printf("max_smem_per_block = %d\n", max_smem_per_block); // printf("max_smem_per_block = %d\n", max_smem_per_block);
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// if (params.h == params.h_k) { // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 8, 2, 2, 2, false, false, T>>(params, stream, configure); // This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
// This is faster, in the case of sequence-parallel bwd (where we need fewer registers). // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
// Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why. // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 2, 2, false, false, T>>(params, stream, configure); if (max_smem_per_block >= 144 * 1024) {
if (max_smem_per_block >= 144 * 1024) { run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_dropout>(params, stream);
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_dropout>(params, stream, configure); // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
// run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure); // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream);
// run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream, configure); // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>, Is_dropout>(params, stream);
// run_flash_bwd_seqq_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure); // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>, Is_dropout>(params, stream, configure); // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure); } else {
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream, configure); // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
} else { run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure); }
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream, configure); // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream);
}
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream, configure);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream, configure); // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream);
// } else {
// run_flash_bwd_seqq_parallel<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream, configure);
// }
}); });
} }
template<typename T> template<typename T>
void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream, const bool configure) { void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 160; constexpr static int Headdim = 160;
int device; int device;
cudaGetDevice(&device); cudaGetDevice(&device);
@ -291,17 +269,17 @@ void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream, const bo
if (status_ != cudaSuccess) { if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_); C10_CUDA_CHECK(status_);
} }
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (max_smem_per_block >= 116 * 1024) { if (max_smem_per_block >= 116 * 1024) {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure); run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
} else { } else {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream, configure); run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream);
} }
}); });
} }
template<typename T> template<typename T>
void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream, const bool configure) { void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 192; constexpr static int Headdim = 192;
int device; int device;
cudaGetDevice(&device); cudaGetDevice(&device);
@ -311,25 +289,25 @@ void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream, const bo
if (status_ != cudaSuccess) { if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_); C10_CUDA_CHECK(status_);
} }
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (max_smem_per_block >= 136 * 1024) { if (max_smem_per_block >= 136 * 1024) {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure); run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
} else { } else {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_dropout>(params, stream, configure); run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_dropout>(params, stream);
} }
}); });
} }
template<typename T> template<typename T>
void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream, const bool configure) { void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 224; constexpr static int Headdim = 224;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure); run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
}); });
} }
template<typename T> template<typename T>
void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream, const bool configure) { void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 256; constexpr static int Headdim = 256;
int device; int device;
cudaGetDevice(&device); cudaGetDevice(&device);
@ -339,14 +317,18 @@ void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream, const bo
if (status_ != cudaSuccess) { if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_); C10_CUDA_CHECK(status_);
} }
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (max_smem_per_block >= 176 * 1024) { // H100 if (max_smem_per_block >= 176 * 1024) { // H100
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure); run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
} else { // A100, we don't do double buffering to save smem } else if (max_smem_per_block >= 144 * 1024) { // A100, we don't do double buffering to save smem
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout>(params, stream, configure); run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout>(params, stream);
} else { // sm86 and sm89, max smem is 99 KB. Only works without dropout. V in regs and no double buffering.
if constexpr (!Is_dropout) {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, false>(params, stream);
}
} }
}); });
} }
}; // namespace pytorch_fmha }; // namespace pytorch_flash
View File
@ -0,0 +1,377 @@
/***************************************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#pragma once
#include <cute/algorithm/copy.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <ATen/native/transformers/cuda/flash_attn/block_info.h>
#include <ATen/native/transformers/cuda/flash_attn/kernel_traits.h>
#include <ATen/native/transformers/cuda/flash_attn/utils.h>
namespace pytorch_flash {
using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int THREADS_PER_ROW, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void dot_do_o(Tensor<Engine0, Layout0> const &do_, Tensor<Engine0, Layout0> const &o,
Tensor<Engine1, Layout1> &dP_sum, const int gdP_col_stride, const float scale) {
static_assert(Layout0::rank == 3, "Only support 3D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(do_.layout() == o.layout());
// Reshape do_ and o from (8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, 8 * kHeadDim / 64)
// The last coordinate is the "page".
Tensor do_reshaped = make_tensor(do_.data(), make_layout(get<1>(do_.layout()),
make_layout(get<0>(do_.layout()),
get<2>(do_.layout()))));
Tensor o_reshaped = make_tensor(o.data(), do_reshaped.layout());
Tensor do_fp32 = pytorch_flash::convert_type<float>(do_reshaped);
Tensor o_fp32 = pytorch_flash::convert_type<float>(o_reshaped);
#pragma unroll
for (int mi = 0; mi < size<0>(do_reshaped); ++mi) {
float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0);
#pragma unroll
for (int ni = 1; ni < size<1>(do_reshaped); ni++) {
dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni);
}
pytorch_flash::SumOp<float> sum_op;
dP_sum_cur = pytorch_flash::Allreduce<THREADS_PER_ROW>::run(dP_sum_cur, sum_op) * scale;
if (threadIdx.x % THREADS_PER_ROW == 0) {
dP_sum(mi * gdP_col_stride + threadIdx.x / THREADS_PER_ROW) = dP_sum_cur;
}
}
}
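Numerically, `dot_do_o` produces `dP_sum[m] = scale * dot(dO[m, :], O[m, :])`; the reshapes and the `Allreduce` over `THREADS_PER_ROW` lanes are only the warp-level plumbing for that reduction, and `scale` is passed as `p_dropout` further down so that `dP - dP_sum` stays on one scale. A scalar reference version for sanity-checking, with made-up data:

```cpp
#include <cstdio>
#include <vector>

// Scalar reference for dot_do_o: dP_sum[m] = scale * sum_d dO[m][d] * O[m][d].
void dot_do_o_reference(const std::vector<float>& dO, const std::vector<float>& O,
                        std::vector<float>& dP_sum, int rows, int head_dim, float scale) {
  for (int m = 0; m < rows; ++m) {
    float acc = 0.f;
    for (int d = 0; d < head_dim; ++d) {
      acc += dO[m * head_dim + d] * O[m * head_dim + d];
    }
    dP_sum[m] = acc * scale;
  }
}

int main() {
  const int rows = 2, head_dim = 4;
  std::vector<float> dO = {1, 1, 1, 1, 2, 0, 0, 0};
  std::vector<float> O  = {1, 2, 3, 4, 5, 6, 7, 8};
  std::vector<float> dP_sum(rows);
  dot_do_o_reference(dO, O, dP_sum, rows, head_dim, /*scale=*/1.f);
  std::printf("%.1f %.1f\n", dP_sum[0], dP_sum[1]);  // 10.0 10.0
  return 0;
}
```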
////////////////////////////////////////////////////////////////////////////////////////////////////
// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel.
// This is used in the case where we want to parallelize the backward across seqlen_k.
template<bool Clear_dQaccum=true, typename Kernel_traits, typename Params>
inline __device__ void compute_dot_do_o(const Params &params) {
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
const int m_block = blockIdx.x;
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.z;
// The thread index.
const int tidx = threadIdx.x;
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
const BlockInfo binfo(params, bidb);
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
+ m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride;
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
+ (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM;
Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.do_row_stride, _1{}));
Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.o_row_stride, _1{}));
Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.h * params.d_rounded, _1{}));
Tensor dP_sum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
Shape<Int<kBlockM>>{}, Stride<_1>{});
typename Kernel_traits::GmemTiledCopydO gmem_tiled_copy_dO;
auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
// TODO: careful, we're zeroing out dQaccum with type float4, but when
// we do atomicAdds, we use type float. The layouts are different. Check this.
typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dQaccum;
auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);
Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO);
Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO);
Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);
Tensor cdO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor tdOcdO = gmem_thr_copy_dO.partition_S(cdO);
// Allocate predicate tensors for k
Tensor tdOpdO = make_tensor<bool>(make_shape(size<2>(tdOgdO)));
// Set predicates for k bounds
#pragma unroll
for (int k = 0; k < size(tdOpdO); ++k) {tdOpdO(k) = get<1>(tdOcdO(0, 0, k)) < params.d;}
Tensor tdOrdO = make_fragment_like(tdOgdO);
Tensor tdOrO = make_fragment_like(tdOgO);
pytorch_flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_dO, tdOgdO, tdOrdO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
);
pytorch_flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_dO, tdOgO, tdOrO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
);
// By right we need to scale dP up by 1/p_dropout, but instead we don't and only scale the final
// results (dQ and dK) by 1/p_dropout. So we need to keep dP_sum scaled down by p_dropout here,
// so that (dP - dP_sum) is on the same scale.
dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, dP_sum,
Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
if (Clear_dQaccum) {
// We're actually not zero'ing out all of dQaccum, but only the part that we're going to
// do atomicAdds on.
Tensor zero = make_fragment_like(tdQgdQaccum);
clear(zero);
cute::copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, typename Params>
inline __device__ void clear_dKVaccum(const Params &params) {
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
const int n_block = blockIdx.x;
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.z;
// The thread index.
const int tidx = threadIdx.x;
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
const BlockInfo binfo(params, bidb);
if (n_block * kBlockN >= binfo.actual_seqlen_k) return;
const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded + n_block * kBlockN) * params.d_rounded;
Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dk_accum_ptr) + row_offset_dkv_accum),
Shape<Int<kBlockN>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});
Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dv_accum_ptr) + row_offset_dkv_accum),
Shape<Int<kBlockN>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});
typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dKVaccum;
auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);
Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_D(gdKaccum);
Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_D(gdVaccum);
Tensor zero = make_fragment_like(tdKgdKaccum);
clear(zero);
cute::copy(gmem_tiled_copy_dKVaccum, zero, tdKgdKaccum);
cute::copy(gmem_tiled_copy_dKVaccum, zero, tdVgdVaccum);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert dQ from dQaccum (in float) to fp16/bf16.
// This is used in the case where we want to parallelize the backward across seqlen_k.
template<typename Kernel_traits, typename Params>
inline __device__ void convert_dQ(const Params &params, const int nsplits) {
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
// Shared memory.
extern __shared__ char smem_[];
const int m_block = blockIdx.x;
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.z;
// The thread index.
const int tidx = threadIdx.x;
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
const BlockInfo binfo(params, bidb);
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
+ m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
+ (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.dq_row_stride, _1{}));
Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.h * params.d_rounded, _1{}));
Tensor sdQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
typename Kernel_traits::SmemLayoutdQ{});
typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;
auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dQaccum;
auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);
typename Kernel_traits::TiledMmadQ tiled_mma_dq;
auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq);
auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx);
Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N)
Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum);
Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K
CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum));
Tensor tdQrdQaccum = make_fragment_like(tdQgdQaccum);
clear(acc_dq);
for (int s = 0; s < nsplits; ++s) {
cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, tdQrdQaccum);
#pragma unroll
for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) += tdQrdQaccum(i); }
tdQgdQaccum.data() = tdQgdQaccum.data() + params.dq_accum_split_stride;
}
#pragma unroll
for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; }
// Convert acc_dq from fp32 to fp16
Tensor rdQ = pytorch_flash::convert_type<Element>(acc_dq);
Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N)
cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
__syncthreads();
Tensor tdQrdQ = make_tensor<Element>(shape(tdQgdQ));
cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ);
Tensor cdQ = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
Tensor tdQpdQ = make_tensor<bool>(make_shape(size<2>(tdQgdQ)));
#pragma unroll
for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(0, 0, k)) < params.d; }
// Clear_OOB_K must be false since we don't want to write zeros to gmem
pytorch_flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM
);
}
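
To make the split-KV bookkeeping above easier to follow: each split writes its partial dQ into a separate fp32 accumulator slab, and this kernel simply sums the slabs and applies one combined softmax/dropout rescale before the fp16/bf16 conversion. A host-side sketch of that reduction, with illustrative names (`scale_softmax_rp_dropout` mirrors the param used above; the container layout is an assumption):

```cpp
#include <cstddef>
#include <vector>

// Illustrative host-side analogue of the loop above: sum the per-split fp32 dQ
// accumulators, then apply one combined rescale before the low-precision cast.
std::vector<float> combine_dq_splits(const std::vector<std::vector<float>>& dq_accum_splits,
                                     float scale_softmax_rp_dropout) {
    std::vector<float> dq(dq_accum_splits.front().size(), 0.0f);
    for (const auto& split : dq_accum_splits) {                // "for (int s = 0; s < nsplits; ++s)"
        for (std::size_t i = 0; i < dq.size(); ++i) {
            dq[i] += split[i];                                 // acc_dq(i) += tdQrdQaccum(i)
        }
    }
    for (float& v : dq) { v *= scale_softmax_rp_dropout; }     // single rescale at the end
    return dq;                                                 // caller would convert to fp16/bf16
}
```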
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert dK and dV from dKaccum and dVaccum (in float) to fp16/bf16.
// This is used in the case where we want to parallelize the backward across seqlen_q.
template<typename Kernel_traits, typename Params>
inline __device__ void convert_dKV(const Params &params) {
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
// Shared memory.
extern __shared__ char smem_[];
const int n_block = blockIdx.x;
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.z;
// The thread index.
const int tidx = threadIdx.x;
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
const BlockInfo binfo(params, bidb);
if (n_block * kBlockN >= binfo.actual_seqlen_k) return;
const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
+ n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
+ n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded
+ n_block * kBlockN) * params.d_rounded;
Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.dk_row_stride, _1{}));
Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.dv_row_stride, _1{}));
Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dk_accum_ptr) + row_offset_dkv_accum),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
Stride<Int<kHeadDim>, _1>{});
Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dv_accum_ptr) + row_offset_dkv_accum),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
Stride<Int<kHeadDim>, _1>{});
Tensor sdK = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
typename Kernel_traits::SmemLayoutdKV{});
Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K)
typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dKV;
auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dKVaccum;
auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);
typename Kernel_traits::TiledMmadKV tiled_mma_dkv;
auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv);
auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx);
Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK); // ((Atom,AtomNum),PIPE_M,PIPE_N)
Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV); // ((Atom,AtomNum),PIPE_M,PIPE_N)
Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_S(gdKaccum);
Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_S(gdVaccum);
Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K
Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K
CUTE_STATIC_ASSERT_V(size(acc_dk) == size(tdKgdKaccum));
CUTE_STATIC_ASSERT_V(size(acc_dv) == size(tdVgdVaccum));
Tensor tdKrdKaccum = make_fragment_like(tdKgdKaccum);
Tensor tdVrdVaccum = make_fragment_like(tdVgdVaccum);
cute::copy(gmem_tiled_copy_dKVaccum, tdKgdKaccum, tdKrdKaccum);
cute::copy(gmem_tiled_copy_dKVaccum, tdVgdVaccum, tdVrdVaccum);
#pragma unroll
for (int i = 0; i < size(acc_dk); ++i) {
acc_dk(i) = tdKrdKaccum(i) * params.scale_softmax_rp_dropout;
}
#pragma unroll
for (int i = 0; i < size(acc_dv); ++i) {
acc_dv(i) = tdVrdVaccum(i) * params.rp_dropout;
}
// Convert acc_dk from fp32 to fp16
Tensor rdK = pytorch_flash::convert_type<Element>(acc_dk);
Tensor rdV = pytorch_flash::convert_type<Element>(acc_dv);
Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK); // ((Atom,AtomNum), MMA_N, MMA_N)
Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV); // ((Atom,AtomNum), MMA_N, MMA_N)
cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
__syncthreads();
Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK);
cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV);
Tensor cdKV = make_identity_tensor(Shape<Int<kBlockN>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));
#pragma unroll
for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
// Clear_OOB_K must be false since we don't want to write zeros to gmem
pytorch_flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
);
pytorch_flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
);
}
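
One detail worth calling out in convert_dKV: dK is rescaled by `scale_softmax_rp_dropout` while dV only uses `rp_dropout`, because the softmax scale enters the dK gradient but not the dV gradient. A tiny sketch of the distinction, assuming (as an assumption, not something shown in this file) that `scale_softmax_rp_dropout` is the product of the softmax scale and the dropout correction:

```cpp
// Illustrative only: dK and dV are finalized with different factors.
// Assumption: scale_softmax_rp_dropout == scale_softmax * rp_dropout.
inline float finalize_dk(float dk_accum, float scale_softmax, float rp_dropout) {
    return dk_accum * scale_softmax * rp_dropout;   // mirrors acc_dk(i) * scale_softmax_rp_dropout
}
inline float finalize_dv(float dv_accum, float rp_dropout) {
    return dv_accum * rp_dropout;                   // mirrors acc_dv(i) * rp_dropout
}
```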
} // namespace pytorch_flash

@ -1,23 +1,23 @@
 /******************************************************************************
- * Copyright (c) 2023, Tri Dao.
+ * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/
 #pragma once
-#include <cmath>
 #include <cute/algorithm/copy.hpp>
-#include <cute/algorithm/gemm.hpp>
 #include <cutlass/cutlass.h>
 #include <cutlass/array.h>
 #include <cutlass/numeric_types.h>
-#include <cutlass/numeric_conversion.h>
 #include <ATen/native/transformers/cuda/flash_attn/block_info.h>
 #include <ATen/native/transformers/cuda/flash_attn/kernel_traits.h>
 #include <ATen/native/transformers/cuda/flash_attn/utils.h>
 #include <ATen/native/transformers/cuda/flash_attn/softmax.h>
-#include <ATen/native/transformers/cuda/flash_attn/philox.cuh>
+#include <ATen/native/transformers/cuda/flash_attn/mask.h>
+#include <ATen/native/transformers/cuda/flash_attn/dropout.h>
+#include <ATen/native/transformers/cuda/flash_attn/rotary.h>
 namespace pytorch_flash {
@ -25,57 +25,7 @@ using namespace cute;
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1, typename Tensor2> template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum,
Tensor2 &acc_o, float softmax_scale_log2) {
if (Is_first) {
pytorch_flash::template reduce_max</*zero_init=*/true>(scores, scores_max);
pytorch_flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
pytorch_flash::reduce_sum(scores, scores_sum);
} else {
Tensor scores_max_prev = make_fragment_like(scores_max);
cute::copy(scores_max, scores_max_prev);
pytorch_flash::template reduce_max</*zero_init=*/false>(scores, scores_max);
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), pytorch_flash::convert_layout_acc_rowcol(acc_o.layout()));
#pragma unroll
for (int mi = 0; mi < size(scores_max); ++mi) {
float scores_max_cur = !Check_inf
? scores_max(mi)
: (scores_max(mi) == -INFINITY ? 0.0f : scores_max(mi));
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
scores_sum(mi) *= scores_scale;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
}
pytorch_flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
Tensor scores_sum_cur = make_fragment_like(scores_sum);
pytorch_flash::reduce_sum(scores, scores_sum_cur);
#pragma unroll
for (int mi = 0; mi < size(scores_sum); ++mi) { scores_sum(mi) += scores_sum_cur(mi); }
}
};
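
The removed helper above is the classic online-softmax update that now lives in the `Softmax` class introduced by this update: keep a running row max and row sum, and whenever the max grows, rescale the previously accumulated output and sum by `exp(old_max - new_max)`. A minimal scalar sketch of the same update (the kernel uses `exp2f` with `softmax_scale_log2` folded in; this sketch uses plain `exp` for clarity):

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Scalar sketch of one online-softmax step for a single query row: fold in a new
// block of scores, rescaling the previously accumulated output when the max grows.
void online_softmax_step(const std::vector<float>& block_scores,            // raw scores for this KV block
                         const std::vector<std::vector<float>>& block_v,    // matching V rows
                         float& running_max, float& running_sum,
                         std::vector<float>& acc_out) {                     // running sum of p * V
    float new_max = running_max;
    for (float s : block_scores) { new_max = std::max(new_max, s); }
    const float correction = std::exp(running_max - new_max);               // rescale older contributions
    running_sum *= correction;
    for (float& o : acc_out) { o *= correction; }
    for (std::size_t j = 0; j < block_scores.size(); ++j) {
        const float p = std::exp(block_scores[j] - new_max);
        running_sum += p;
        for (std::size_t d = 0; d < acc_out.size(); ++d) { acc_out[d] += p * block_v[j][d]; }
    }
    running_max = new_max;
}
```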
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename TiledCopy>
inline __device__ void write_softmax_to_gmem(
Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> &tPgP, TiledCopy gmem_tiled_copy_P
) {
// Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N)
Layout l = tOrP.layout();
Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l))));
CUTE_STATIC_ASSERT_V(size<2>(tPgP) == _1{});
CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP));
#pragma unroll
for (int mi = 0; mi < size<1>(tPrP); ++mi) {
cute::copy(gmem_tiled_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) { inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
using Element = typename Kernel_traits::Element; using Element = typename Kernel_traits::Element;
@ -93,6 +43,19 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
constexpr int kHeadDim = Kernel_traits::kHeadDim; constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kNWarps = Kernel_traits::kNWarps; constexpr int kNWarps = Kernel_traits::kNWarps;
auto seed_offset = at::cuda::philox::unpack(params.philox_args);
pytorch_flash::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t,
bidb, bidh, tidx, params.h);
// Save seed and offset for backward. If we don't have this here, the 0-th thread block might
// exit early and no one saves the rng state.
if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
if (params.philox_args.captured_) {
*params.seed = std::get<0>(seed_offset);
*params.extragraph_offset = std::get<1>(seed_offset);
}
}
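
The reason this save moved to the top of the kernel is spelled out in the comment: if every thread block that owns rows past `actual_seqlen_q` returns early, the old placement could leave the RNG state unsaved for the backward pass. A hypothetical sketch of the ordering pitfall (names are illustrative, not this kernel's API):

```cpp
// Hypothetical illustration of the ordering issue fixed by hoisting the save:
// state needed by a later (backward) pass must be recorded before any early return,
// otherwise no thread block may end up writing it.
struct RngState { unsigned long long seed; unsigned long long offset; };

void process_row_block(int row_block, int valid_row_blocks, RngState philox,
                       RngState* saved_for_backward, bool is_first_block) {
    if (is_first_block && saved_for_backward != nullptr) {
        *saved_for_backward = philox;          // saved unconditionally, before the early exit below
    }
    if (row_block >= valid_row_blocks) {
        return;                                // with the old placement, this path skipped the save
    }
    // ... main work that consumes `philox` for dropout ...
    (void)philox;
}
```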
const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb); const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
if (m_block * kBlockM >= binfo.actual_seqlen_q) return; if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
@ -108,15 +71,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0. // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0.
// Otherwise we might read OOB elements from gK and gV. // Otherwise we might read OOB elements from gK and gV.
if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) { if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) {
// Save seed and offset for backward. If we don't have this here, the 0-th thread block might
// exit early and no one saves the rng state.
if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
auto seeds = at::cuda::philox::unpack(params.philox_args);
if (params.philox_args.captured_) {
*params.seed = std::get<0>(seeds);
*params.extragraph_offset = std::get<1>(seeds);
}
}
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
@ -191,8 +145,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopyP gmem_tiled_copy_P;
auto gmem_thr_copy_P = gmem_tiled_copy_P.get_thread_slice(tidx);
Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
@ -200,7 +152,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
Tensor tPgP = gmem_thr_copy_P.partition_D(gP);
typename Kernel_traits::TiledMma tiled_mma; typename Kernel_traits::TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tidx); auto thr_mma = tiled_mma.get_thread_slice(tidx);
@ -208,6 +159,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N)
Tensor tSgS = thr_mma.partition_C(gP);
Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_M, MMA_K Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_M, MMA_K
// //
@ -228,10 +181,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
// TODO: this might need to change if we change the mma instruction in SM70
Tensor scores_max = make_tensor<ElementAccum>(Shape<Int<2 * size<1>(acc_o)>>{});
Tensor scores_sum = make_fragment_like(scores_max);
// //
// PREDICATES // PREDICATES
// //
@ -274,16 +223,11 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Prologue // Prologue
Tensor tQrQ = make_fragment_like(tQgQ);
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
pytorch_flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, pytorch_flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
binfo.actual_seqlen_q - m_block * kBlockM); binfo.actual_seqlen_q - m_block * kBlockM);
if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); } if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }
// // Copy rmem to smem
// // copy(tQrQ, tQsQ);
// pytorch_flash::cp_async_wait<0>();
// __syncthreads();
// // if (cute::thread(1, 0)) { print(tQsQ); } // // if (cute::thread(1, 0)) { print(tQsQ); }
// // Tensor sQNoSwizzle = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), typename Kernel_traits::SmemLayoutQNoSwizzle{}); // // Tensor sQNoSwizzle = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), typename Kernel_traits::SmemLayoutQNoSwizzle{});
// // if (cute::thread0()) { print(sQNoSwizzle); } // // if (cute::thread0()) { print(sQNoSwizzle); }
@ -313,17 +257,13 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
} }
auto seeds = at::cuda::philox::unpack(params.philox_args);
if (params.philox_args.captured_) {
*params.seed = std::get<0>(seeds);
*params.extragraph_offset = std::get<1>(seeds);
}
unsigned long long seed = std::get<0>(seeds);
unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32;
clear(acc_o); clear(acc_o);
pytorch_flash::Softmax<2 * size<1>(acc_o)> softmax;
const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
pytorch_flash::Mask<Is_causal, Is_local, Has_alibi> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope);
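
For context on the new `Has_alibi` path: ALiBi adds a per-head linear bias to the raw attention scores, and because this kernel applies the mask to unscaled `acc_s` values that are only multiplied by `scale_softmax` later, the slope is pre-divided by `scale_softmax` so the effective bias is unchanged. A generic sketch of the idea (not this kernel's exact indexing, which also distinguishes causal and local windows):

```cpp
#include <cmath>
#include <cstdlib>

// Generic ALiBi sketch: a per-head linear penalty on query/key distance is
// added to the raw score before the softmax.
float alibi_biased_score(float raw_score, float slope, int query_idx, int key_idx) {
    return raw_score - slope * static_cast<float>(std::abs(query_idx - key_idx));
}

// The kernel adds the bias to *unscaled* scores, which are multiplied by
// scale_softmax only inside the softmax, so it pre-divides the slope to compensate.
float effective_slope(float slope, float scale_softmax) {
    return slope / scale_softmax;
}
```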
// For performance reason, we separate out two kinds of iterations: // For performance reason, we separate out two kinds of iterations:
// those that need masking on S, and those that don't. // those that need masking on S, and those that don't.
// We need masking on S for the very last block when K and V has length not multiple of kBlockN. // We need masking on S for the very last block when K and V has length not multiple of kBlockN.
@ -360,37 +300,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
); );
// if (cute::thread0()) { print(acc_s); } // if (cute::thread0()) { print(acc_s); }
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) mask.template apply_mask<Is_causal, Is_even_MN>(
Tensor scores = make_tensor(acc_s.data(), pytorch_flash::convert_layout_acc_rowcol(acc_s.layout())); acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
// if (cute::thread0()) { print_tensor(scores); } );
// We don't put the masking before the matmul S = Q K^T because we don't clear sK
// for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
// can produce Inf / NaN.
if (!Is_causal && !Is_local) {
if (!Is_even_MN) { pytorch_flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
} else {
// Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n)
// Tensor taccScS = thr_mma.partition_C(caccS); // (MMA,MMA_M,MMA_N)
// static_assert(decltype(size<0>(taccScS))::value == 4);
// // Convert to ((2, 2), MMA_M, MMA_N) then take only the row indices.
// Tensor idx_row = logical_divide(taccScS, Shape<_2>{})(make_coord(0, _), _, 0);
// Tensor idx_rowcol = make_tensor(taccScS.data(), pytorch_flash::convert_layout_acc_rowcol(taccScS.layout()));
// pytorch_flash::apply_mask_causal_w_idx(scores, idx_rowcol, n_block * kBlockN, binfo.actual_seqlen_k,
// m_block * kBlockM);
// Idk why it's get<1> and not get<0> of the stride.
// if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); }
// I can't get the stride from idx_row
pytorch_flash::apply_mask_local</*HasWSLeft=*/Is_local>(
scores, n_block * kBlockN, binfo.actual_seqlen_k,
// m_block * kBlockM + get<0>(idx_row(0)),
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
binfo.actual_seqlen_q, kNWarps * 16,
params.window_size_left, params.window_size_right
// m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16
// m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16
);
// if (cute::thread0()) { print_tensor(scores); }
}
pytorch_flash::cp_async_wait<0>(); pytorch_flash::cp_async_wait<0>();
__syncthreads(); __syncthreads();
@ -405,33 +317,31 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// TODO: when we have key_padding_mask we'll need to Check_inf // TODO: when we have key_padding_mask we'll need to Check_inf
masking_step == 0 masking_step == 0
? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) ? softmax.template softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2)
: softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); : softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2);
// Convert scores from fp32 to fp16/bf16 // Convert acc_s from fp32 to fp16/bf16
Tensor rP = pytorch_flash::convert_type<Element>(scores); Tensor rP = pytorch_flash::convert_type<Element>(acc_s);
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
int block_col_idx = n_block * (kBlockN / 32); int block_col_idx = n_block * (kBlockN / 32);
if (Return_softmax) { if (Return_softmax) {
Tensor tOrP_copy = make_fragment_like(tOrP); Tensor rP_drop = make_fragment_like(rP);
cute::copy(tOrP, tOrP_copy); cute::copy(rP, rP_drop);
pytorch_flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>( dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
tOrP_copy, params.p_dropout_in_uint8_t, seed, offset, rP_drop, block_row_idx, block_col_idx, kNWarps
block_row_idx, block_col_idx, kNWarps
); );
pytorch_flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P); cute::copy(rP_drop, tSgS);
tPgP.data() = tPgP.data() + (-kBlockN); tSgS.data() = tSgS.data() + (-kBlockN);
} }
if (Is_dropout) { if (Is_dropout) {
pytorch_flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset, dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);
block_row_idx, block_col_idx, kNWarps);
} }
// if (cute::thread0()) { print(tOrP); }
pytorch_flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
// if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
// if (cute::thread0()) { print(tOrP); }
pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
// if (cute::thread0()) { print(scores); } // if (cute::thread0()) { print(scores); }
// This check is at the end of the loop since we always have at least 1 iteration // This check is at the end of the loop since we always have at least 1 iteration
@ -468,58 +378,37 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
cute::cp_async_fence(); cute::cp_async_fence();
} }
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) mask.template apply_mask</*Causal_mask=*/false>(
Tensor scores = make_tensor(acc_s.data(), pytorch_flash::convert_layout_acc_rowcol(acc_s.layout())); acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { );
pytorch_flash::apply_mask_local(
scores, n_block * kBlockN, binfo.actual_seqlen_k,
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
binfo.actual_seqlen_q, kNWarps * 16,
params.window_size_left, params.window_size_right
);
}
softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
Tensor rP = pytorch_flash::convert_type<Element>(scores); softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(acc_s, acc_o, params.scale_softmax_log2);
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. Tensor rP = pytorch_flash::convert_type<Element>(acc_s);
Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
int block_col_idx = n_block * (kBlockN / 32); int block_col_idx = n_block * (kBlockN / 32);
if (Return_softmax) { if (Return_softmax) {
Tensor tOrP_copy = make_fragment_like(tOrP); Tensor rP_drop = make_fragment_like(rP);
cute::copy(tOrP, tOrP_copy); cute::copy(rP, rP_drop);
pytorch_flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>( dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
tOrP_copy, params.p_dropout_in_uint8_t, seed, offset, rP_drop, block_row_idx, block_col_idx, kNWarps
block_row_idx, block_col_idx, kNWarps
); );
pytorch_flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P); cute::copy(rP_drop, tSgS);
tPgP.data() = tPgP.data() + (-kBlockN); tSgS.data() = tSgS.data() + (-kBlockN);
} }
if (Is_dropout) { if (Is_dropout) {
pytorch_flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset, dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);
block_row_idx, block_col_idx, kNWarps);
} }
pytorch_flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
// if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
} }
// Epilogue // Epilogue
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) Tensor lse = softmax.template normalize_softmax_lse<Is_dropout>(acc_o, params.scale_softmax, params.rp_dropout);
Tensor acc_o_rowcol = make_tensor(acc_o.data(), pytorch_flash::convert_layout_acc_rowcol(acc_o.layout()));
Tensor lse = make_fragment_like(scores_sum);
#pragma unroll
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
float sum = scores_sum(mi);
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
lse(mi) = (sum == 0.f || sum != sum) ? INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum);
float scale = !Is_dropout ? inv_sum : inv_sum * params.rp_dropout;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
}
// if (cute::thread0()) { print(acc_o_rowcol); }
// Convert acc_o from fp32 to fp16/bf16 // Convert acc_o from fp32 to fp16/bf16
Tensor rO = pytorch_flash::convert_type<Element>(acc_o); Tensor rO = pytorch_flash::convert_type<Element>(acc_o);
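
The removed epilogue loop above is what `normalize_softmax_lse` now encapsulates: per row, `lse = max * scale_softmax + log(sum)`, the output is divided by the row sum, and empty or NaN rows fall back to an infinite LSE with a scale of 1 (the split-KV variant uses -inf instead, as the other removed copy below shows). A scalar sketch of that finalization:

```cpp
#include <cmath>
#include <utility>

// Scalar sketch of the removed per-row finalization (now normalize_softmax_lse):
// returns {lse, output_scale} for one row from its running max and running sum.
std::pair<float, float> finalize_row(float row_max, float row_sum,
                                     float scale_softmax, float rp_dropout, bool is_dropout) {
    const bool degenerate = (row_sum == 0.f) || (row_sum != row_sum);   // empty or NaN row
    const float inv_sum = degenerate ? 1.f : 1.f / row_sum;
    const float lse = degenerate ? INFINITY : row_max * scale_softmax + std::log(row_sum);
    const float out_scale = is_dropout ? inv_sum * rp_dropout : inv_sum;
    return {lse, out_scale};                                            // each acc_o entry *= out_scale
}
```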
@ -585,7 +474,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params> template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) { inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) {
using Element = typename Kernel_traits::Element; using Element = typename Kernel_traits::Element;
@ -673,10 +562,17 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
+ m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
// We move K and V to the last block. // We move K and V to the last block.
const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb]; const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb];
const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) const int *block_table = params.block_table == nullptr ? nullptr : params.block_table + bidb * params.block_table_batch_stride;
+ (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; const int block_table_idx = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN / params.page_block_size;
const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) const int block_table_offset = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN - block_table_idx * params.page_block_size;
+ (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; const index_t row_offset_k = block_table == nullptr
? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache)
+ (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride
: block_table[block_table_idx] * params.k_batch_stride + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
const index_t row_offset_v = block_table == nullptr
? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache)
+ (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride
: block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
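
These offsets implement the paged-KV (block table) indexing picked up with this kernel update: a logical key index is first mapped through `block_table` to a physical page, then to a row within that page. A sketch of the mapping with illustrative parameter names (strides follow the params used above):

```cpp
#include <cstdint>

// Sketch of the paged-KV address calculation: map a logical key index to a
// physical page via the block table, then to a row inside that page.
std::int64_t paged_kv_offset(const int* block_table, int logical_key_idx, int page_block_size,
                             std::int64_t batch_stride, std::int64_t row_stride,
                             std::int64_t head_offset) {
    const int page_idx    = logical_key_idx / page_block_size;               // block_table_idx
    const int row_in_page = logical_key_idx - page_idx * page_block_size;    // block_table_offset
    return static_cast<std::int64_t>(block_table[page_idx]) * batch_stride   // jump to the physical page
         + static_cast<std::int64_t>(row_in_page) * row_stride               // then to the row
         + head_offset;                                                      // then to the head
}
```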
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q), Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{}, Shape<Int<kBlockM>, Int<kHeadDim>>{},
@ -730,11 +626,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
// TODO: this might need to change if we change the mma instruction in SM70
Tensor scores_max = make_tensor<ElementAccum>(Shape<Int<2 * size<1>(acc_o)>>{});
Tensor scores_sum = make_fragment_like(scores_max);
//
// PREDICATES // PREDICATES
// //
@ -814,11 +705,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K)
const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN); const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN);
auto tKgK_data = tKgK.data();
auto tVgV_data = tVgV.data();
for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) { for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) {
pytorch_flash::copy_w_min_idx<Is_even_K>( pytorch_flash::copy_w_min_idx<Is_even_K>(
tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
); );
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride));
if (params.rotary_dim == 0) { if (params.rotary_dim == 0) {
pytorch_flash::copy_w_min_idx<Is_even_K>( pytorch_flash::copy_w_min_idx<Is_even_K>(
@ -844,19 +736,30 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
} }
} }
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride));
if (block_table == nullptr) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
} else {
if (n_block > n_block_copy_min) {
const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur];
const int offset_diff = block_table_offset_next - block_table_offset_cur;
tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride;
tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride;
}
}
} }
// Need this before we can read in K again, so that we'll see the updated K values. // Need this before we can read in K again, so that we'll see the updated K values.
__syncthreads(); __syncthreads();
if (n_block_max > n_block_copy_min) { tKgK.data() = tKgK_data;
tKgK.data() = tKgK.data() + (n_block_max - n_block_copy_min) * kBlockN * params.k_row_stride; tVgV.data() = tVgV_data;
tVgV.data() = tVgV.data() + (n_block_max - n_block_copy_min) * kBlockN * params.v_row_stride;
}
} }
// Read Q from gmem to smem, optionally apply rotary embedding. // Read Q from gmem to smem, optionally apply rotary embedding.
Tensor tQrQ = make_fragment_like(tQgQ);
if (!Append_KV || params.rotary_dim == 0) { if (!Append_KV || params.rotary_dim == 0) {
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
pytorch_flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, pytorch_flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
@ -907,6 +810,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
clear(acc_o); clear(acc_o);
pytorch_flash::Softmax<2 * size<1>(acc_o)> softmax;
const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
pytorch_flash::Mask<Is_causal, Is_local, Has_alibi> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope);
// For performance reason, we separate out two kinds of iterations: // For performance reason, we separate out two kinds of iterations:
// those that need masking on S, and those that don't. // those that need masking on S, and those that don't.
// We need masking on S for the very last block when K and V has length not multiple of kBlockN. // We need masking on S for the very last block when K and V has length not multiple of kBlockN.
@ -927,7 +835,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// Advance gV // Advance gV
if (masking_step > 0) { if (masking_step > 0) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); if (block_table == nullptr) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
} else {
const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size;
const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size;
const int block_table_idx_next = n_block * kBlockN / params.page_block_size;
const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
}
pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
} else { } else {
// Clear the smem tiles to account for predicated off loads // Clear the smem tiles to account for predicated off loads
@ -943,21 +859,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
); );
// if (cute::thread0()) { print(acc_s); } // if (cute::thread0()) { print(acc_s); }
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) mask.template apply_mask<Is_causal, Is_even_MN>(
Tensor scores = make_tensor(acc_s.data(), pytorch_flash::convert_layout_acc_rowcol(acc_s.layout())); acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
// if (cute::thread0()) { print(scores); } );
// We don't put the masking before the matmul S = Q K^T because we don't clear sK
// for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
// can produce Inf / NaN.
if (!Is_causal && !Is_local) {
if (!Is_even_MN) { pytorch_flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
} else {
pytorch_flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k,
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
binfo.actual_seqlen_q, kNWarps * 16,
params.window_size_left, params.window_size_right
);
}
pytorch_flash::cp_async_wait<0>(); pytorch_flash::cp_async_wait<0>();
__syncthreads(); __syncthreads();
@ -966,7 +870,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
if (n_block > n_block_min) { if (n_block > n_block_min) {
// Advance gK // Advance gK
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); if (block_table == nullptr) {
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
} else {
const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
const int block_table_offset_next =(n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;
}
pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization // This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions. // isn't right and we get race conditions.
@ -975,18 +887,17 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// We have key_padding_mask so we'll need to Check_inf // We have key_padding_mask so we'll need to Check_inf
masking_step == 0 masking_step == 0
? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) ? softmax.template softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(acc_s, acc_o, params.scale_softmax_log2)
: softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); : softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(acc_s, acc_o, params.scale_softmax_log2);
// if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); } // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); }
// Convert scores from fp32 to fp16/bf16 // Convert acc_s from fp32 to fp16/bf16
Tensor rP = pytorch_flash::convert_type<Element>(scores); Tensor rP = pytorch_flash::convert_type<Element>(acc_s);
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout())); Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
pytorch_flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
// if (cute::thread0()) { print(scores); }
// This check is at the end of the loop since we always have at least 1 iteration // This check is at the end of the loop since we always have at least 1 iteration
if (n_masking_steps > 1 && n_block <= n_block_min) { if (n_masking_steps > 1 && n_block <= n_block_min) {
@ -1002,7 +913,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
pytorch_flash::cp_async_wait<0>(); pytorch_flash::cp_async_wait<0>();
__syncthreads(); __syncthreads();
// Advance gV // Advance gV
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); if (block_table == nullptr) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
} else {
const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size;
const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size;
const int block_table_idx_next = n_block * kBlockN / params.page_block_size;
const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
}
pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
cute::cp_async_fence(); cute::cp_async_fence();
@ -1015,50 +934,38 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
__syncthreads(); __syncthreads();
if (n_block > n_block_min) { if (n_block > n_block_min) {
// Advance gK // Advance gK
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); if (block_table == nullptr) {
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
} else {
const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;
}
pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization // This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions. // isn't right and we get race conditions.
cute::cp_async_fence(); cute::cp_async_fence();
} }
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) mask.template apply_mask</*Causal_mask=*/false>(
Tensor scores = make_tensor(acc_s.data(), pytorch_flash::convert_layout_acc_rowcol(acc_s.layout())); acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { );
pytorch_flash::apply_mask_local( softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(acc_s, acc_o, params.scale_softmax_log2);
scores, n_block * kBlockN, binfo.actual_seqlen_k,
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
binfo.actual_seqlen_q, kNWarps * 16,
params.window_size_left, params.window_size_right
);
}
softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
Tensor rP = pytorch_flash::convert_type<Element>(scores); Tensor rP = pytorch_flash::convert_type<Element>(acc_s);
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout())); Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
pytorch_flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
} }
// Epilogue // Epilogue
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(acc_o, params.scale_softmax);
Tensor acc_o_rowcol = make_tensor(acc_o.data(), pytorch_flash::convert_layout_acc_rowcol(acc_o.layout()));
// if (cute::thread0()) { print(acc_o_rowcol); }
Tensor lse = make_fragment_like(scores_sum);
#pragma unroll
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
float sum = scores_sum(mi);
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : scores_max(mi) * params.scale_softmax + __logf(sum);
float scale = inv_sum;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
}
// if (cute::thread0()) { print(lse); } // if (cute::thread0()) { print(lse); }
// if (cute::thread0()) { print(acc_o_rowcol); }
Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
// Partition sO to match the accumulator partitioning // Partition sO to match the accumulator partitioning
@ -1135,7 +1042,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params> template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void compute_attn(const Params &params) { inline __device__ void compute_attn(const Params &params) {
const int m_block = blockIdx.x; const int m_block = blockIdx.x;
// The block index for the batch. // The block index for the batch.
@ -1151,12 +1058,12 @@ inline __device__ void compute_attn(const Params &params) {
// the attention matrix. This way, as long as we have the batch, head, and the location of // the attention matrix. This way, as long as we have the batch, head, and the location of
// the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.
pytorch_flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block); pytorch_flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
} }
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params> template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
inline __device__ void compute_attn_splitkv(const Params &params) { inline __device__ void compute_attn_splitkv(const Params &params) {
const int m_block = blockIdx.x; const int m_block = blockIdx.x;
// The block index for the batch. // The block index for the batch.
@ -1165,7 +1072,7 @@ inline __device__ void compute_attn_splitkv(const Params &params) {
const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z; const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z;
const int n_split_idx = Split ? blockIdx.y : 0; const int n_split_idx = Split ? blockIdx.y : 0;
const int num_n_splits = Split ? gridDim.y : 1; const int num_n_splits = Split ? gridDim.y : 1;
pytorch_flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local, Is_even_MN, Is_even_K, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits); pytorch_flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
} }
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
@ -1330,6 +1237,4 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
} }
} }
-////////////////////////////////////////////////////////////////////////////////////////////////////
-} // namespace flash
+} // namespace pytorch_flash

@ -12,27 +12,40 @@
 namespace pytorch_flash {
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
-__global__ void flash_fwd_kernel(Flash_fwd_params params) {
-    #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
-    static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
-    pytorch_flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, Return_softmax>(params);
+// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#define ARCH_SUPPORTS_FLASH
+#define KERNEL_PARAM_MODIFIER __grid_constant__
+#else
+#define KERNEL_PARAM_MODIFIER
+#endif
+// Define a macro for unsupported architecture handling to centralize the error message
+#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!");
+// Use a macro to clean up kernel definitions
+#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) \
+template<typename Kernel_traits, __VA_ARGS__> \
+__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params)
+DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax) {
+    #if defined(ARCH_SUPPORTS_FLASH)
+        static_assert(!(Is_causal && Is_local)); // Enforce constraints
+        pytorch_flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params);
     #else
-    printf("FATAL: FlashAttention requires to be build with sm80-sm90, but was built for < 8.0!");
+        FLASH_UNSUPPORTED_ARCH
     #endif
 }
-template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV>
-__global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) {
-    #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
-    pytorch_flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Is_even_MN, Is_even_K, Split, Append_KV>(params);
+DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV) {
+    #if defined(ARCH_SUPPORTS_FLASH)
+        pytorch_flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params);
     #else
-    printf("FATAL: FlashAttention requires to be build with sm80-sm90, but was built for < 8.0!");
+        FLASH_UNSUPPORTED_ARCH
     #endif
 }
-template<typename Kernel_traits, int kBlockM, int Log_max_splits, bool Is_even_K>
-__global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) {
+DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K) {
     static_assert(Log_max_splits >= 1);
     pytorch_flash::combine_attn_seqk_parallel<Kernel_traits, kBlockM, Log_max_splits, Is_even_K>(params);
 }
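
The `DEFINE_FLASH_FORWARD_KERNEL` macro only stamps out the template head and the `__global__` signature (with `__grid_constant__` on sm80+); the body follows the macro invocation as usual. A stripped-down CUDA sketch of the same pattern with placeholder names, not the real kernel types:

```cuda
#include <cstdio>

// Placeholder types/names (DemoParams, demo_kernel) illustrating the macro pattern:
// the macro emits only the template head and the __global__ signature; the body
// is supplied right after the invocation.
struct DemoParams { int n; };

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#define DEMO_PARAM_MODIFIER __grid_constant__
#else
#define DEMO_PARAM_MODIFIER
#endif

#define DEMO_DEFINE_KERNEL(kernelName, ...) \
  template <__VA_ARGS__>                    \
  __global__ void kernelName(DEMO_PARAM_MODIFIER const DemoParams params)

DEMO_DEFINE_KERNEL(demo_kernel, bool Flag) {
    if (Flag) { printf("n = %d\n", params.n); }   // kernel body written after the macro, as above
}
```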
@ -52,27 +65,30 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
const bool is_even_K = params.d == Kernel_traits::kHeadDim; const bool is_even_K = params.d == Kernel_traits::kHeadDim;
const bool return_softmax = params.p_ptr != nullptr; const bool return_softmax = params.p_ptr != nullptr;
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] { LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
// Will only return softmax if dropout, to reduce compilation time. ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. // Will only return softmax if dropout, to reduce compilation time.
// If return_softmax, set IsEvenMNConst to false to reduce number of templates // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates // If return_softmax, set IsEvenMNConst to false to reduce number of templates
// If Is_local, set Is_causal to false // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>; // If Is_local, set Is_causal to false
// printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>; // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
if (smem_size >= 48 * 1024) { // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
C10_CUDA_CHECK(cudaFuncSetAttribute( // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); if (smem_size >= 48 * 1024) {
} C10_CUDA_CHECK(cudaFuncSetAttribute(
// int ctas_per_sm; kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
// cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( }
// &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); // int ctas_per_sm;
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params); // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
C10_CUDA_KERNEL_LAUNCH_CHECK(); // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}); });
}); });
}); });
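// The *_SWITCH macros used above (BOOL_SWITCH, EVENK_SWITCH, LOCAL_SWITCH, ALIBI_SWITCH,
// DROPOUT_SWITCH) live in a static_switch.h-style header that is not shown in this diff. A
// minimal, self-contained sketch of the underlying pattern, assuming the specialized variants
// only differ in being able to compile one branch out entirely:
#define DEMO_BOOL_SWITCH(COND, CONST_NAME, ...)        \
  [&] {                                                \
    if (COND) {                                        \
      constexpr static bool CONST_NAME = true;         \
      return __VA_ARGS__();                            \
    } else {                                           \
      constexpr static bool CONST_NAME = false;        \
      return __VA_ARGS__();                            \
    }                                                  \
  }()
// Each switch turns a runtime bool into a constexpr template argument, so the lambda body is
// instantiated once per value; nesting N independent switches therefore multiplies the number of
// kernel template instantiations by up to 2^N, which is why the comments above try to collapse
// combinations (e.g. forcing IsEvenMNConst to false) to keep compile time in check.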
@ -90,22 +106,24 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
const bool is_even_K = params.d == Kernel_traits::kHeadDim; const bool is_even_K = params.d == Kernel_traits::kHeadDim;
BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] {
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] { LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
BOOL_SWITCH(params.num_splits > 1, Split, [&] { BOOL_SWITCH(params.num_splits > 1, Split, [&] {
BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
// If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
// If Is_local, set Is_causal to false // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV>; // If Is_local, set Is_causal to false
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>; auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV>;
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>; // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
if (smem_size >= 48 * 1024) { // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
C10_CUDA_CHECK(cudaFuncSetAttribute( if (smem_size >= 48 * 1024) {
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); C10_CUDA_CHECK(cudaFuncSetAttribute(
} kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params); }
C10_CUDA_KERNEL_LAUNCH_CHECK(); kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}); });
}); });
}); });
@ -118,7 +136,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
// If headdim is divisible by 64, then we set kBlockM = 8, etc. // If headdim is divisible by 64, then we set kBlockM = 8, etc.
constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16); constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM); dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
if (params.num_splits <= 2) { if (params.num_splits <= 2) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params); flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 4) { } else if (params.num_splits <= 4) {
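// Worked example of the combine-kernel sizing above (illustrative values): kBlockM is 4 when the
// head dim is a multiple of 128, 8 when it is a multiple of 64, and 16 otherwise, so for
// Headdim = 128 with b = 2, h = 16, seqlen_q = 4096 the combine grid is
// ceil(2 * 16 * 4096 / 4) = 32768 blocks. Log_max_splits only has to satisfy
// num_splits <= 2^Log_max_splits, consistent with num_splits <= 2 selecting Log_max_splits = 1.
static_assert((128 % 128 == 0 ? 4 : (128 % 64 == 0 ? 8 : 16)) == 4, "hdim 128 -> kBlockM = 4");
static_assert((96 % 128 == 0 ? 4 : (96 % 64 == 0 ? 8 : 16)) == 16, "hdim 96 -> kBlockM = 16");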
@ -152,7 +170,7 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream)
template<typename T> template<typename T>
void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 32; constexpr static int Headdim = 32;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}); });
@ -162,7 +180,7 @@ void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T> template<typename T>
void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 64; constexpr static int Headdim = 64;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if constexpr(!Is_dropout) { if constexpr(!Is_dropout) {
// Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
@ -186,7 +204,7 @@ void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 96; constexpr static int Headdim = 96;
auto dprops = at::cuda::getCurrentDeviceProperties(); auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm8x = dprops->major == 8 && dprops->minor > 0; bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] {
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
if (is_sm8x) { if (is_sm8x) {
@ -212,7 +230,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 128; constexpr static int Headdim = 128;
auto dprops = at::cuda::getCurrentDeviceProperties(); auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm8x = dprops->major == 8 && dprops->minor > 0; bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if constexpr(!Is_dropout) { if constexpr(!Is_dropout) {
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
@ -249,7 +267,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 160; constexpr static int Headdim = 160;
auto dprops = at::cuda::getCurrentDeviceProperties(); auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm8x = dprops->major == 8 && dprops->minor > 0; bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] {
// For A100, H100, 128 x 32 is the fastest. // For A100, H100, 128 x 32 is the fastest.
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
@ -277,7 +295,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T> template<typename T>
void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 192; constexpr static int Headdim = 192;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if constexpr(!Is_dropout) { if constexpr(!Is_dropout) {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream); run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
@ -305,7 +323,7 @@ void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
C10_CUDA_CHECK(status_); C10_CUDA_CHECK(status_);
} }
// printf("max_smem_per_block = %d\n", max_smem_per_block); // printf("max_smem_per_block = %d\n", max_smem_per_block);
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream); run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
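// Arithmetic behind the "112 KB" threshold above: with Headdim = 224 and 2-byte elements
// (fp16/bf16), a 128-row Q tile plus 64-row K and V tiles need
// 2 * 224 * (128 + 2 * 64) = 114688 bytes = 112 KiB of shared memory. Launches that large also
// rely on the cudaFuncAttributeMaxDynamicSharedMemorySize opt-in performed in run_flash_fwd,
// since dynamic shared memory is capped at 48 KiB per block by default.
static_assert(2 * 224 * (128 + 2 * 64) == 112 * 1024, "112 KiB shared-memory budget");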
@ -336,7 +354,7 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
C10_CUDA_CHECK(status_); C10_CUDA_CHECK(status_);
} }
// printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block); // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] {
// For A100, we want to run with 128 x 64 (128KB smem). // For A100, we want to run with 128 x 64 (128KB smem).
// For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM. // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
@ -353,4 +371,4 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
}); });
} }
}; // namespace pytorch_fmha }; // namespace pytorch_flash


@ -1,5 +1,5 @@
/****************************************************************************** /******************************************************************************
* Copyright (c) 2023, Tri Dao. * Copyright (c) 2024, Tri Dao.
******************************************************************************/ ******************************************************************************/
#pragma once #pragma once
@ -26,7 +26,7 @@ struct Flash_kernel_traits {
#endif #endif
using ElementAccum = float; using ElementAccum = float;
using index_t = uint32_t; using index_t = int64_t;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
using MMA_Atom_Arch = std::conditional_t< using MMA_Atom_Arch = std::conditional_t<
@ -91,20 +91,10 @@ struct Flash_fwd_kernel_traits : public Base {
SmemLayoutAtomQ{}, SmemLayoutAtomQ{},
Shape<Int<kBlockN>, Int<kHeadDim>>{})); Shape<Int<kBlockN>, Int<kHeadDim>>{}));
// This has to be kBlockN and not 8, otherwise we get wrong results for d=128 // https://github.com/ColfaxResearch/cutlass-kernels/blob/a222587e6d59b93ba704853d3946fb686d8b8892/src/fmha/fmha_forward.cu#L434
using SmemLayoutAtomVtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>, using SmemLayoutVtransposed = decltype(
Stride<_1, Int<kBlockKSmem>>>; composition(SmemLayoutKV{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockN>>{}, GenRowMajor{})));
using SmemLayoutAtomVtransposed = decltype( using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));
composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomVtransposedNoSwizzle{}));
using SmemLayoutVtransposed = decltype(tile_to_shape(
SmemLayoutAtomVtransposed{},
Shape<Int<kHeadDim>, Int<kBlockN>>{}));
// Maybe the VtransposeNoSwizzle just needs to have the right shape
// And the strides don't matter?
using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape(
SmemLayoutAtomVtransposedNoSwizzle{},
Shape<Int<kHeadDim>, Int<kBlockN>>{}));
// using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
using SmemLayoutAtomO = decltype( using SmemLayoutAtomO = decltype(
composition(Swizzle<kSwizzle, 3, 3>{}, composition(Swizzle<kSwizzle, 3, 3>{},
@ -116,10 +106,8 @@ struct Flash_fwd_kernel_traits : public Base {
using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>; using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>; using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;
static constexpr int kSmemQCount = size(SmemLayoutQ{}); static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element);
static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
static constexpr int kSmemQSize = kSmemQCount * sizeof(Element);
static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize; static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
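// Worked example for the shared-memory sizes above, assuming 2-byte elements (fp16/bf16) and a
// 128 x 64 tile with kHeadDim = 128 (i.e. kBlockM = 128, kBlockN = 64):
//   kSmemQSize  = size(SmemLayoutQ)  * sizeof(Element)     = 128 * 128 * 2     = 32 KiB
//   kSmemKVSize = size(SmemLayoutKV) * 2 * sizeof(Element) = 2 * 64 * 128 * 2  = 32 KiB
//   kSmemSize   = kSmemQSize + kSmemKVSize = 64 KiB, or max(32, 32) = 32 KiB when Share_Q_K_smem
//   kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element) = 16 / 2 = 8 elements per 128-bit load
static_assert(128 * 128 * 2 + 2 * 64 * 128 * 2 == 64 * 1024, "example forward smem budget");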
@ -149,15 +137,6 @@ struct Flash_fwd_kernel_traits : public Base {
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{}, make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtom{}, GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP");
using GmemLayoutAtomP = Layout<Shape <Int<kNThreads / kGmemThreadsPerRowP>, Int<kGmemThreadsPerRowP>>,
Stride<Int<kGmemThreadsPerRowP>, _1>>;
using GmemTiledCopyP = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtomP{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
using GmemLayoutAtomOaccum = std::conditional_t< using GmemLayoutAtomOaccum = std::conditional_t<
kBlockKSmem == 32, kBlockKSmem == 32,
@ -244,26 +223,18 @@ struct Flash_bwd_kernel_traits : public Base {
SmemLayoutAtomKV{}, SmemLayoutAtomKV{},
make_shape(Int<kBlockN>{}, Int<kHeadDim>{}))); make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
using SmemLayoutAtomKtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>, using SmemLayoutKtransposed = decltype(
Stride<_1, Int<kBlockKSmem>>>; composition(SmemLayoutKV{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockN>>{}, GenRowMajor{})));
using SmemLayoutAtomKtransposed = decltype( using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{}));
composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomKtransposedNoSwizzle{}));
using SmemLayoutKtransposed = decltype(tile_to_shape(
SmemLayoutAtomKtransposed{},
make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
// Maybe the KtransposeNoSwizzle just needs to have the right shape
// And the strides don't matter?
using SmemLayoutKtransposedNoSwizzle = decltype(tile_to_shape(
SmemLayoutAtomKtransposedNoSwizzle{},
make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
// using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn());
// TODO: generalize to other values of kBlockN // TODO: generalize to other values of kBlockN
// TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2 // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
// static constexpr int kPBlockN = kBlockN; // static constexpr int kPBlockN = kBlockN;
static_assert(kBlockN >= 64); // Temporarily disabling this for hdim 256 on sm86 and sm89
// static_assert(kBlockN >= 64);
static_assert(kBlockN >= 32);
// TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest. // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest.
static constexpr int kPBlockN = 64; static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32;
static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64); static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64);
// static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3); // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3);
static constexpr int kSwizzlePdS = 3; static constexpr int kSwizzlePdS = 3;
@ -274,30 +245,15 @@ struct Flash_bwd_kernel_traits : public Base {
using SmemLayoutPdS = decltype(tile_to_shape( using SmemLayoutPdS = decltype(tile_to_shape(
SmemLayoutAtomPdS{}, SmemLayoutAtomPdS{},
make_shape(Int<kBlockM>{}, Int<kBlockN>{}))); make_shape(Int<kBlockM>{}, Int<kBlockN>{})));
using SmemLayoutAtomPdStransposedNoSwizzle = Layout<Shape<Int<kPBlockN>, Int<kBlockM>>, using SmemLayoutPdStransposed = decltype(
Stride<_1, Int<kPBlockN>>>; composition(SmemLayoutPdS{}, make_layout(Shape<Int<kBlockN>, Int<kBlockM>>{}, GenRowMajor{})));
using SmemLayoutAtomPdStransposed = decltype( using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{}));
composition(Swizzle<kSwizzlePdS, 3, 3>{}, SmemLayoutAtomPdStransposedNoSwizzle{}));
using SmemLayoutPdStransposed = decltype(tile_to_shape(
SmemLayoutAtomPdStransposed{},
make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
using SmemLayoutPdStransposedNoSwizzle = decltype(tile_to_shape(
SmemLayoutAtomPdStransposedNoSwizzle{},
make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
// using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn());
using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>; using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;
using SmemLayoutAtomQdOtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockM>>, using SmemLayoutQdOtransposed = decltype(
Stride<_1, Int<kBlockKSmem>>>; composition(SmemLayoutQdO{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockM>>{}, GenRowMajor{})));
using SmemLayoutAtomQdOtransposed = decltype( using SmemLayoutQdOtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutQdOtransposed{}));
composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomQdOtransposedNoSwizzle{}));
using SmemLayoutQdOtransposed = decltype(tile_to_shape(
SmemLayoutAtomQdOtransposed{},
make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
using SmemLayoutQdOtransposedNoSwizzle = decltype(tile_to_shape(
SmemLayoutAtomQdOtransposedNoSwizzle{},
make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
// using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn());
using SmemLayoutAtomdKV = decltype( using SmemLayoutAtomdKV = decltype(
composition(Swizzle<kSwizzle, 3, 3>{}, composition(Swizzle<kSwizzle, 3, 3>{},
@ -317,16 +273,12 @@ struct Flash_bwd_kernel_traits : public Base {
make_shape(Int<kBlockM>{}, Int<kHeadDim>{}))); make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
using SmemCopyAtomdQ = Copy_Atom<DefaultCopy, elem_type>; using SmemCopyAtomdQ = Copy_Atom<DefaultCopy, elem_type>;
static constexpr int kSmemQdOCount = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3); // Double buffer for sQ // Double buffer for sQ
static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element);
static constexpr int kSmemdSCount = size(SmemLayoutPdS{}); static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
static constexpr int kSmemPCount = size(SmemLayoutPdS{}); static constexpr int kSmemdSSize = size(SmemLayoutPdS{}) * sizeof(Element);
static constexpr int kSmemdQCount = size(SmemLayoutdQ{}); static constexpr int kSmemPSize = size(SmemLayoutPdS{}) * sizeof(Element);
static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element); static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element);
static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element);
static constexpr int kSmemPSize = kSmemPCount * sizeof(Element);
static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element);
static constexpr int kSmemSize = kSmemQdOSize static constexpr int kSmemSize = kSmemQdOSize
+ (!Is_V_in_regs + (!Is_V_in_regs
? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
@ -335,9 +287,6 @@ struct Flash_bwd_kernel_traits : public Base {
+ (!Is_V_in_regs + (!Is_V_in_regs
? kSmemKVSize + kSmemdSSize + kSmemPSize ? kSmemKVSize + kSmemdSSize + kSmemPSize
: std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize)); : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize));
static constexpr int kSmemSize1rowblock = kSmemQdOSize / 3 * 2 + kSmemKVSize / 2 * 3
+ kSmemdSSize + kSmemPSize;
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
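// Worked example of the backward shared-memory budget above, using illustrative 64 x 64 tiles,
// kHeadDim = 64, 2-byte elements, and Q/dO double buffering (No_double_buffer == false):
//   kSmemQdOSize = 3 * 64 * 64 * 2 = 24 KiB   (Q, its double buffer, and dO)
//   kSmemKVSize  = 2 * 64 * 64 * 2 = 16 KiB
//   kSmemdSSize  = kSmemPSize = kSmemdQSize = 64 * 64 * 2 = 8 KiB each
//   kSmemSize (V not kept in registers) = 24 + 16 + 8 + max(8, 8) = 56 KiB
// The real kBlockM/kBlockN per head dim are picked in the run_mha_bwd_hdim* dispatchers, so these
// numbers are only meant to show how the pieces add up.
static_assert(3 * 64 * 64 * 2 + 2 * 64 * 64 * 2 + 64 * 64 * 2 + 64 * 64 * 2 == 56 * 1024,
              "example backward smem budget");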


@ -1,161 +0,0 @@
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
#include <cute/algorithm/copy.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/layout/layout.h>
#include <cutlass/numeric_types.h>
namespace pytorch_flash{
using namespace cute;
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::half_t>
struct Flash_kernel_traits_sm90 {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
using Element = elem_type;
static constexpr bool Has_cp_async = true;
#else
using Element = cutlass::half_t;
static constexpr bool Has_cp_async = false;
#endif
using ElementAccum = float;
using index_t = uint32_t;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
using MMA_Atom_Arch = std::conditional_t<
std::is_same_v<elem_type, cutlass::half_t>,
MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
>;
using ValLayoutMNK = Layout<Shape<_1, _2, _1>>;
#else
using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
using ValLayoutMNK = Layout<Shape<_1, _2, _2>>;
#endif
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, elem_type>;
using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, elem_type>;
#else
using SmemCopyAtom = Copy_Atom<DefaultCopy, elem_type>;
using SmemCopyAtomTransposed = Copy_Atom<DefaultCopy, elem_type>;
#endif
};
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, bool Is_Q_in_regs_=false, bool Share_Q_K_smem_=false, typename elem_type=cutlass::half_t,
typename Base=Flash_kernel_traits_sm90<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
struct Flash_fwd_kernel_traits : public Base {
using Element = typename Base::Element;
using ElementAccum = typename Base::ElementAccum;
using index_t = typename Base::index_t;
static constexpr bool Has_cp_async = Base::Has_cp_async;
using SmemCopyAtom = typename Base::SmemCopyAtom;
using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;
static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;
// The number of threads.
static constexpr int kNWarps = kNWarps_;
static constexpr int kNThreads = kNWarps * 32;
static constexpr int kBlockM = kBlockM_;
static constexpr int kBlockN = kBlockN_;
static constexpr int kHeadDim = kHeadDim_;
static_assert(kHeadDim % 32 == 0);
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
using TiledMma = TiledMMA<
typename Base::MMA_Atom_Arch,
Layout<Shape<Int<kNWarps>,_1,_1>>, // 4x1x1 or 8x1x1 thread group
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
using SmemLayoutAtomQ = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
// This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
Layout<Shape<_8, Int<kBlockKSmem>>,
Stride<Int<kBlockKSmem>, _1>>{}));
using SmemLayoutQ = decltype(tile_to_shape(
SmemLayoutAtomQ{},
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
using SmemLayoutKV = decltype(tile_to_shape(
SmemLayoutAtomQ{},
Shape<Int<kBlockN>, Int<kHeadDim>>{}));
using SmemLayoutAtomVtransposed = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
// This has to be kBlockN and not 8, otherwise we get wrong results for d=128
Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
Stride<_1, Int<kBlockKSmem>>>{}));
using SmemLayoutVtransposed = decltype(tile_to_shape(
SmemLayoutAtomVtransposed{},
Shape<Int<kHeadDim>, Int<kBlockN>>{}));
// Maybe the VtransposeNoSwizzle just needs to have the right shape
// And the strides don't matter?
using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
using SmemLayoutAtomO = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
Layout<Shape<Int<8>, Int<kBlockKSmem>>,
Stride<Int<kBlockKSmem>, _1>>{}));
using SmemLayoutO = decltype(tile_to_shape(
SmemLayoutAtomO{},
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
using SmemCopyAtomO = Copy_Atom<DefaultCopy, elem_type>;
static constexpr int kSmemQCount = size(SmemLayoutQ{});
static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
static constexpr int kSmemQSize = kSmemQCount * sizeof(Element);
static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
// Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.
// For example, for d=128, smem is split into 2 "pages", each page takes care of columns
// 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem,
// thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page,
// to the same banks.
static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
Stride<Int<kGmemThreadsPerRow>, _1>>;
// We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
// from the same address by the same threadblock. This is slightly faster.
using Gmem_copy_struct = std::conditional_t<
Has_cp_async,
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
DefaultCopy
>;
using GmemTiledCopyQKV = decltype(
make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
using GmemTiledCopyO = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP");
using GmemLayoutAtomP = Layout<Shape <Int<kNThreads / kGmemThreadsPerRowP>, Int<kGmemThreadsPerRowP>>,
Stride<Int<kGmemThreadsPerRowP>, _1>>;
using GmemTiledCopyP = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
GmemLayoutAtomP{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
};
} // namespace pytorch_flash
////////////////////////////////////////////////////////////////////////////////////////////////////


@ -8,7 +8,7 @@
namespace pytorch_flash{ namespace pytorch_flash{
template<> template<>
void run_mha_bwd_<cutlass::bfloat16_t, 128>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) { void run_mha_bwd_<cutlass::bfloat16_t, 128>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim128<cutlass::bfloat16_t>(params, stream, configure); run_mha_bwd_hdim128<cutlass::bfloat16_t>(params, stream);
} }
} // namespace pytorch_flash } // namespace pytorch_flash


@ -8,7 +8,7 @@
namespace pytorch_flash{ namespace pytorch_flash{
template<> template<>
void run_mha_bwd_<cutlass::half_t, 128>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) { void run_mha_bwd_<cutlass::half_t, 128>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim128<cutlass::half_t>(params, stream, configure); run_mha_bwd_hdim128<cutlass::half_t>(params, stream);
} }
} // namespace pytorch_flash } // namespace pytorch_flash


@ -8,7 +8,7 @@
namespace pytorch_flash{ namespace pytorch_flash{
template<> template<>
void run_mha_bwd_<cutlass::bfloat16_t, 160>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) { void run_mha_bwd_<cutlass::bfloat16_t, 160>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim160<cutlass::bfloat16_t>(params, stream, configure); run_mha_bwd_hdim160<cutlass::bfloat16_t>(params, stream);
} }
} // namespace pytorch_flash } // namespace pytorch_flash


@ -8,7 +8,7 @@
namespace pytorch_flash{ namespace pytorch_flash{
template<> template<>
void run_mha_bwd_<cutlass::half_t, 160>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) { void run_mha_bwd_<cutlass::half_t, 160>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim160<cutlass::half_t>(params, stream, configure); run_mha_bwd_hdim160<cutlass::half_t>(params, stream);
} }
} // namespace pytorch_flash } // namespace pytorch_flash


@ -8,7 +8,7 @@
namespace pytorch_flash{ namespace pytorch_flash{
template<> template<>
void run_mha_bwd_<cutlass::bfloat16_t, 192>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) { void run_mha_bwd_<cutlass::bfloat16_t, 192>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim192<cutlass::bfloat16_t>(params, stream, configure); run_mha_bwd_hdim192<cutlass::bfloat16_t>(params, stream);
} }
} // namespace pytorch_flash } // namespace pytorch_flash


@ -8,7 +8,7 @@
namespace pytorch_flash{ namespace pytorch_flash{
template<> template<>
void run_mha_bwd_<cutlass::half_t, 192>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) { void run_mha_bwd_<cutlass::half_t, 192>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim192<cutlass::half_t>(params, stream, configure); run_mha_bwd_hdim192<cutlass::half_t>(params, stream);
} }
} // namespace pytorch_flash } // namespace pytorch_flash


@ -8,7 +8,7 @@
namespace pytorch_flash{ namespace pytorch_flash{
template<> template<>
void run_mha_bwd_<cutlass::bfloat16_t, 224>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) { void run_mha_bwd_<cutlass::bfloat16_t, 224>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim224<cutlass::bfloat16_t>(params, stream, configure); run_mha_bwd_hdim224<cutlass::bfloat16_t>(params, stream);
} }
} // namespace pytorch_flash } // namespace pytorch_flash


@ -8,7 +8,7 @@
namespace pytorch_flash{ namespace pytorch_flash{
template<> template<>
void run_mha_bwd_<cutlass::half_t, 224>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) { void run_mha_bwd_<cutlass::half_t, 224>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim224<cutlass::half_t>(params, stream, configure); run_mha_bwd_hdim224<cutlass::half_t>(params, stream);
} }
} // namespace pytorch_flash } // namespace pytorch_flash


@ -8,7 +8,7 @@
namespace pytorch_flash{ namespace pytorch_flash{
template<> template<>
void run_mha_bwd_<cutlass::bfloat16_t, 256>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) { void run_mha_bwd_<cutlass::bfloat16_t, 256>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim256<cutlass::bfloat16_t>(params, stream, configure); run_mha_bwd_hdim256<cutlass::bfloat16_t>(params, stream);
} }
} // namespace pytorch_flash } // namespace pytorch_flash


@ -8,7 +8,7 @@
namespace pytorch_flash{ namespace pytorch_flash{
template<> template<>
void run_mha_bwd_<cutlass::half_t, 256>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) { void run_mha_bwd_<cutlass::half_t, 256>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim256<cutlass::half_t>(params, stream, configure); run_mha_bwd_hdim256<cutlass::half_t>(params, stream);
} }
} // namespace pytorch_flash } // namespace pytorch_flash


@ -8,7 +8,7 @@
namespace pytorch_flash{ namespace pytorch_flash{
template<> template<>
void run_mha_bwd_<cutlass::bfloat16_t, 32>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) { void run_mha_bwd_<cutlass::bfloat16_t, 32>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim32<cutlass::bfloat16_t>(params, stream, configure); run_mha_bwd_hdim32<cutlass::bfloat16_t>(params, stream);
} }
} // namespace pytorch_flash } // namespace pytorch_flash


@ -8,7 +8,7 @@
namespace pytorch_flash{ namespace pytorch_flash{
template<> template<>
void run_mha_bwd_<cutlass::half_t, 32>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) { void run_mha_bwd_<cutlass::half_t, 32>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim32<cutlass::half_t>(params, stream, configure); run_mha_bwd_hdim32<cutlass::half_t>(params, stream);
} }
} // namespace pytorch_flash } // namespace pytorch_flash


@ -8,7 +8,7 @@
namespace pytorch_flash{ namespace pytorch_flash{
template<> template<>
void run_mha_bwd_<cutlass::bfloat16_t, 64>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) { void run_mha_bwd_<cutlass::bfloat16_t, 64>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim64<cutlass::bfloat16_t>(params, stream, configure); run_mha_bwd_hdim64<cutlass::bfloat16_t>(params, stream);
} }
} // namespace pytorch_flash } // namespace pytorch_flash


@ -8,7 +8,7 @@
namespace pytorch_flash{ namespace pytorch_flash{
template<> template<>
void run_mha_bwd_<cutlass::half_t, 64>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) { void run_mha_bwd_<cutlass::half_t, 64>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim64<cutlass::half_t>(params, stream, configure); run_mha_bwd_hdim64<cutlass::half_t>(params, stream);
} }
} // namespace pytorch_flash } // namespace pytorch_flash


@ -8,7 +8,7 @@
namespace pytorch_flash{ namespace pytorch_flash{
template<> template<>
void run_mha_bwd_<cutlass::bfloat16_t, 96>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) { void run_mha_bwd_<cutlass::bfloat16_t, 96>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim96<cutlass::bfloat16_t>(params, stream, configure); run_mha_bwd_hdim96<cutlass::bfloat16_t>(params, stream);
} }
} // namespace pytorch_flash } // namespace pytorch_flash


@ -8,7 +8,7 @@
namespace pytorch_flash{ namespace pytorch_flash{
template<> template<>
void run_mha_bwd_<cutlass::half_t, 96>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) { void run_mha_bwd_<cutlass::half_t, 96>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim96<cutlass::half_t>(params, stream, configure); run_mha_bwd_hdim96<cutlass::half_t>(params, stream);
} }
} // namespace pytorch_flash } // namespace pytorch_flash


@ -27,8 +27,8 @@ template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params
KERNEL_IMPL_TEMPLATE_BWD = """ KERNEL_IMPL_TEMPLATE_BWD = """
template<> template<>
void run_mha_bwd_<{DTYPE}, {HEAD_DIM}>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {{ void run_mha_bwd_<{DTYPE}, {HEAD_DIM}>(Flash_bwd_params &params, cudaStream_t stream) {{
run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream, configure); run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream);
}} }}
""" """


@ -0,0 +1,213 @@
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#pragma once
#include <cute/tensor.hpp>
namespace pytorch_flash {
using namespace cute;
template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
const int col_idx_offset_ = 0) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
const int lane_id = threadIdx.x % 32;
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
if (col_idx >= max_seqlen_k) {
// Without the "make_coord" we get wrong results
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
tensor(mi, make_coord(j, nj)) = -INFINITY;
}
}
}
}
}
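// Host-side sketch (not used by the kernel) of which key columns a given lane masks above: the
// lane's column pair inside an MMA tile is fixed by lane_id % 4, and the (j, nj) loops step by
// 1 and 8 columns respectively, mirroring the (2, MMA_N) accumulator layout.
inline int demo_mask_col_idx(int col_idx_offset_, int lane_id, int nj, int j) {
  const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;  // two adjacent columns per lane
  return col_idx_offset + nj * 8 + j;                              // j in {0, 1}, nj over MMA_N
}
// e.g. lane 5 with nj = 1, j = 1 and a zero block offset lands on column (5 % 4) * 2 + 8 + 1 = 11,
// which is set to -INFINITY whenever 11 >= max_seqlen_k.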
template <bool HasWSLeft=true, typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
const int max_seqlen_k, const int row_idx_offset,
const int max_seqlen_q, const int warp_row_stride,
const int window_size_left, const int window_size_right) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
const int lane_id = threadIdx.x % 32;
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
#pragma unroll
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
const int row_idx_base = row_idx_offset + mi * warp_row_stride;
#pragma unroll
for (int i = 0; i < size<0, 0>(tensor); ++i) {
const int row_idx = row_idx_base + i * 8;
const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) {
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
}
}
}
// if (cute::thread0()) {
// printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k);
// print(tensor(make_coord(i, mi), _));
// // print(tensor(_, j + nj * size<1, 0>(tensor)));
// }
}
}
}
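// Worked example of the sliding-window limits above (illustrative sizes): with max_seqlen_q = 4,
// max_seqlen_k = 8, window_size_left = 2, window_size_right = 1, query row 1 is right-aligned to
// key position 1 + (8 - 4) = 5, so it may attend to columns
//   [max(0, 1 + 8 - 4 - 2), min(8, 1 + 1 + 8 - 4 + 1)) = [3, 7)
// i.e. keys 3..6; every other column in that row is set to -INFINITY.
static_assert(1 + 8 - 4 - 2 == 3 && 1 + 1 + 8 - 4 + 1 == 7, "window limits for the example row");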
template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
const int max_seqlen_k, const int row_idx_offset,
const int max_seqlen_q, const int warp_row_stride) {
// Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,
max_seqlen_q, warp_row_stride, -1, 0);
}
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void apply_mask_causal_w_idx(
Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
{
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 2, "Only support 2D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol));
CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0)));
#pragma unroll
for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
tensor(mi, ni) = -INFINITY;
}
}
// if (cute::thread0()) {
// printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k);
// print(tensor(_, make_coord(j, ni)));
// // print(tensor(_, j + ni * size<1, 0>(tensor)));
// }
}
}
template <bool Is_causal, bool Is_local, bool Has_alibi>
struct Mask {
const int max_seqlen_k, max_seqlen_q;
const int window_size_left, window_size_right;
const float alibi_slope;
__forceinline__ __device__ Mask(const int max_seqlen_k, const int max_seqlen_q,
const int window_size_left, const int window_size_right,
const float alibi_slope=0.f)
: max_seqlen_k(max_seqlen_k)
, max_seqlen_q(max_seqlen_q)
, window_size_left(window_size_left)
, window_size_right(window_size_right)
, alibi_slope(!Has_alibi ? 0.0 : alibi_slope) {
};
// Causal_mask: whether this particular iteration needs causal masking
template <bool Causal_mask=false, bool Is_even_MN=true, typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor_,
const int col_idx_offset_,
const int row_idx_offset,
const int warp_row_stride) {
static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local");
static_assert(Layout::rank == 3, "Only support 3D Tensor");
static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4");
static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN;
// if (cute::thread0()) { printf("Has_alibi = %d, Causal_mask=%d, Is_local=%d, Is_even_MN = %d, Need_masking = %d\n", Has_alibi, Causal_mask, Is_local, Is_even_MN, Need_masking); }
if constexpr (Need_masking) {
// Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor tensor = make_tensor(tensor_.data(), pytorch_flash::convert_layout_acc_rowcol(tensor_.layout()));
        // Do we need both row and column indices, or just column indices?
static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask;
const int lane_id = threadIdx.x % 32;
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
if constexpr (Col_idx_only) {
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
// No causal, no local
if constexpr (Has_alibi) {
tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
}
if constexpr (!Is_even_MN) {
if (col_idx >= max_seqlen_k) { tensor(mi, make_coord(j, nj)) = -INFINITY; }
}
}
}
}
} else {
#pragma unroll
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
const int row_idx_base = row_idx_offset + mi * warp_row_stride;
#pragma unroll
for (int i = 0; i < size<0, 0>(tensor); ++i) {
const int row_idx = row_idx_base + i * 8;
const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
if constexpr (Has_alibi) {
if constexpr (Is_causal) {
tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope * col_idx;
} else {
tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
}
}
if constexpr (Causal_mask) {
if (col_idx >= col_idx_limit_right) {
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
}
}
if constexpr (Is_local) {
if (col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left) {
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
}
}
if constexpr (!Causal_mask && !Is_local && !Is_even_MN) {
// Causal and Local already handles MN masking
if (col_idx >= max_seqlen_k) {
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
}
}
}
}
}
}
}
}
};
};
} // namespace pytorch_flash
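// Host-side reference (illustrative, not part of the kernel) of the ALiBi term added by
// Mask::apply_mask above: the causal branch adds alibi_slope * col_idx, while the non-causal
// branch subtracts alibi_slope times the distance between the right-aligned query row and the
// key column. Under a causal mask the two forms differ only by a per-row constant, which softmax
// cancels, so the simpler causal expression appears to be an optimization rather than a
// different bias.
inline float demo_alibi_bias(bool is_causal, float alibi_slope, int row_idx, int col_idx,
                             int max_seqlen_q, int max_seqlen_k) {
  if (is_causal) {
    return alibi_slope * static_cast<float>(col_idx);
  }
  int dist = row_idx + max_seqlen_k - max_seqlen_q - col_idx;
  if (dist < 0) { dist = -dist; }
  return -alibi_slope * static_cast<float>(dist);
}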


@ -11,7 +11,7 @@ struct ull2 {
unsigned long long y; unsigned long long y;
}; };
inline __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) { __forceinline__ __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
uint2 *res; uint2 *res;
unsigned long long tmp; unsigned long long tmp;
asm ("mul.wide.u32 %0, %1, %2;\n\t" asm ("mul.wide.u32 %0, %1, %2;\n\t"
@ -21,7 +21,7 @@ inline __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
return *res; return *res;
} }
inline __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) { __forceinline__ __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) {
constexpr unsigned long kPhiloxSA = 0xD2511F53; constexpr unsigned long kPhiloxSA = 0xD2511F53;
constexpr unsigned long kPhiloxSB = 0xCD9E8D57; constexpr unsigned long kPhiloxSB = 0xCD9E8D57;
uint2 res0 = mulhilo32(kPhiloxSA, ctr.x); uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);
@ -30,7 +30,7 @@ inline __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) {
return ret; return ret;
} }
inline __device__ uint4 philox(unsigned long long seed, __forceinline__ __device__ uint4 philox(unsigned long long seed,
unsigned long long subsequence, unsigned long long subsequence,
unsigned long long offset) { unsigned long long offset) {
constexpr unsigned long kPhilox10A = 0x9E3779B9; constexpr unsigned long kPhilox10A = 0x9E3779B9;
@ -51,117 +51,3 @@ inline __device__ uint4 philox(unsigned long long seed,
} }
} // namespace flash } // namespace flash
namespace {
class Philox {
public:
__device__ inline Philox(unsigned long long seed,
unsigned long long subsequence,
unsigned long long offset)
: STATE(0)
, seed_(seed)
, offset_(offset)
, key(reinterpret_cast<const uint2&>(seed)) {
//key.x = (unsigned int)seed;
//key.y = (unsigned int)(seed >> 32);
//counter = make_uint4(0, 0, 0, 0);
//counter.z = (unsigned int)(subsequence);
//counter.w = (unsigned int)(subsequence >> 32);
//STATE = 0;
//incr_n(offset / 4);
// key = reinterpret_cast<const uint2&>(seed);
ull2 * tmp = reinterpret_cast<ull2*>(&counter);
tmp->x = offset / 4;
tmp->y = subsequence;
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("Philox counter: %d, %d, %d, %d\n", counter.x, counter.y, counter.z, counter.w);
// }
}
__device__ inline uint4 operator()() {
// // if (STATE == 0) {
// uint4 counter_ = counter;
// uint2 key_ = key;
// // 7-round philox
// #pragma unroll
// for (int i = 0; i < 6; i++) {
// counter_ = pytorch_flash::philox_single_round(counter_, key_);
// key_.x += (kPhilox10A);
// key_.y += (kPhilox10B);
// }
// // output = philox_single_round(counter_, key_);
// uint4 output = pytorch_flash::philox_single_round(counter_, key_);
// // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// // printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
// // printf("Philox output: %u, %u, %u, %u\n", output.x, output.y, output.z, output.w);
// // }
// incr();
// // }
// // return a float4 directly
// // unsigned long ret;
// // switch(STATE) {
// // case 0: ret = output.x; break;
// // case 1: ret = output.y; break;
// // case 2: ret = output.z; break;
// // case 3: ret = output.w; break;
// //}
// // STATE = (STATE + 1) % 4;
// return output;
return pytorch_flash::philox(seed_, offset_, offset_);
}
private:
unsigned long long offset_, seed_;
struct ull2 {
uint64_t x;
uint64_t y;
};
uint4 counter;
// uint4 output;
const uint2 key;
unsigned int STATE;
__device__ inline void incr_n(unsigned long long n) {
unsigned int nlo = (unsigned int)(n);
unsigned int nhi = (unsigned int)(n >> 32);
counter.x += nlo;
if (counter.x < nlo)
nhi++;
counter.y += nhi;
if (nhi <= counter.y)
return;
if (++counter.z)
return;
++counter.w;
}
__device__ uint4 incr128 (uint4 ctr)
{
uint4 res;
asm ("add.cc.u32 %0, %4, %8;\n\t"
"addc.cc.u32 %1, %5, %9;\n\t"
"addc.cc.u32 %2, %6, %10;\n\t"
"addc.u32 %3, %7, %11;\n\t"
: "=r"(res.x), "=r"(res.y), "=r"(res.z), "=r"(res.w)
: "r"(ctr.x), "r"(ctr.y), "r"(ctr.z), "r"(ctr.w),
"n"(1), "n"(0), "n"(0), "n"(0));
return res;
}
__device__ inline void incr() {
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("Counter before: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
// }
counter = incr128(counter);
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("Counter after: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
// }
}
static const unsigned long kPhilox10A = 0x9E3779B9;
static const unsigned long kPhilox10B = 0xBB67AE85;
// static const unsigned long kPhiloxSA = 0xD2511F53;
// static const unsigned long kPhiloxSB = 0xCD9E8D57;
};
} // namespace pytorch_flash


@ -0,0 +1,152 @@
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#pragma once
#include <cute/algorithm/copy.hpp>
#include <ATen/native/transformers/cuda/flash_attn/utils.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace pytorch_flash {
using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
template <bool Is_even_K=true, bool Clear_OOB_K=true,
typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy_rotary_interleaved(Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D,
Tensor<Engine2, Layout2> const &Cos,
Tensor<Engine2, Layout2> const &Sin,
Tensor<Engine3, Layout3> const &identity_MN,
const int max_MN, const int min_MN,
const int dim, const int rotary_dim) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K
CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K
static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2);
static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32
Tensor rCos = make_fragment_like(Cos);
Tensor rSin = make_fragment_like(Sin);
Tensor rS = make_fragment_like(S);
#pragma unroll
for (int m = 0; m < size<1>(S); ++m) {
if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
#pragma unroll
for (int k = 0; k < size<2>(S); ++k) {
if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
cute::copy(S(_, m, k), rS(_, m, k));
if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
cute::copy(Cos(_, m, k), rCos(_, m, k));
cute::copy(Sin(_, m, k), rSin(_, m, k));
Tensor S_fp32 = convert_type<float>(rS(_, m, k));
Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
#pragma unroll
for (int i = 0; i < size<0>(rS) / 2; ++i) {
float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i);
float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i);
S_fp32(2 * i) = real;
S_fp32(2 * i + 1) = imag;
}
// Idk but I need to copy for the convert_type to work
Tensor S_fp32_copy = make_fragment_like(S_fp32);
cute::copy(S_fp32, S_fp32_copy);
using T = typename Engine0::value_type;
Tensor S_og_type = convert_type<T>(S_fp32_copy);
cute::copy(S_og_type, rS(_, m, k));
}
cute::copy(rS(_, m, k), D(_, m, k));
} else if (Clear_OOB_K) {
cute::clear(D(_, m, k));
}
}
}
}
}
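// Scalar reference (illustrative) for the interleaved rotation applied above (often called the
// GPT-J style): adjacent elements (2i, 2i+1) are treated as a complex number and multiplied by
// cos_i + i*sin_i.
inline void demo_rotary_interleaved_pair(float& x0, float& x1, float cos_i, float sin_i) {
  const float real = x0 * cos_i - x1 * sin_i;
  const float imag = x0 * sin_i + x1 * cos_i;
  x0 = real;
  x1 = imag;
}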
////////////////////////////////////////////////////////////////////////////////////////////////////
template <bool Is_even_K=true, bool Clear_OOB_K=true,
typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy_rotary_contiguous(Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D,
Tensor<Engine2, Layout2> const &Cos,
Tensor<Engine2, Layout2> const &Sin,
Tensor<Engine3, Layout3> const &identity_MN,
const int max_MN, const int min_MN,
const int dim, const int rotary_dim) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA
CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin));
static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32
Tensor rCos = make_fragment_like(Cos);
Tensor rSin = make_fragment_like(Sin);
Tensor rS = make_fragment_like(S);
Tensor rS_other = make_fragment_like(rS(_, 0, 0));
#pragma unroll
for (int m = 0; m < size<1>(S); ++m) {
if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
#pragma unroll
for (int k = 0; k < size<2>(S); ++k) {
if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
cute::copy(S(_, m, k), rS(_, m, k));
if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2;
Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout());
cute::copy(gS_other, rS_other);
// if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); }
Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout());
Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout());
cute::copy(gCos, rCos(_, m, k));
cute::copy(gSin, rSin(_, m, k));
// if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); }
Tensor S_fp32 = convert_type<float>(rS(_, m, k));
Tensor S_other_fp32 = convert_type<float>(rS_other);
Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
#pragma unroll
for (int i = 0; i < size<0>(rS); ++i) {
S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i));
}
// Copy S_fp32 into a fresh fragment first; convert_type does not work on it in place here.
Tensor S_fp32_copy = make_fragment_like(S_fp32);
cute::copy(S_fp32, S_fp32_copy);
using T = typename Engine0::value_type;
Tensor S_og_type = convert_type<T>(S_fp32_copy);
cute::copy(S_og_type, rS(_, m, k));
// if (cute::thread0()) { print_tensor(rS(_, m, k)); }
}
cute::copy(rS(_, m, k), D(_, m, k));
} else if (Clear_OOB_K) {
cute::clear(D(_, m, k));
}
}
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace pytorch_flash
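The contiguous variant handled by copy_rotary_contiguous pairs each element in the first half of the rotary span with the element rotary_dim/2 positions away, which is what the is_left branches above implement. A hedged PyTorch sketch, with names and shapes illustrative only:

import torch

def rotary_contiguous_ref(x, cos, sin):
    # x: (..., rotary_dim) split into two contiguous halves; cos/sin: (..., rotary_dim // 2).
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    # Left half: x1 * cos - x2 * sin; right half: x2 * cos + x1 * sin.
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)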

View File

@ -1,34 +1,15 @@
/****************************************************************************** /******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2024, Tri Dao.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/ ******************************************************************************/
#pragma once #pragma once
#include <cmath> #include <cmath>
#include <cuda_fp16.h>
#include <cute/tensor.hpp>
#include <cutlass/numeric_types.h>
#include <ATen/native/transformers/cuda/flash_attn/philox.cuh> #include <ATen/native/transformers/cuda/flash_attn/philox.cuh>
#include <ATen/native/transformers/cuda/flash_attn/utils.h> #include <ATen/native/transformers/cuda/flash_attn/utils.h>
@ -39,7 +20,7 @@ using namespace cute;
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator> template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ inline void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) { __device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor"); static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor"); static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
@ -54,7 +35,7 @@ __device__ inline void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Te
} }
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator> template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ inline void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) { __device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
CUTE_STATIC_ASSERT_V(size(dst) == size(src)); CUTE_STATIC_ASSERT_V(size(dst) == size(src));
#pragma unroll #pragma unroll
for (int i = 0; i < size(dst); i++){ for (int i = 0; i < size(dst); i++){
@ -63,26 +44,26 @@ __device__ inline void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Eng
} }
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator> template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ inline void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) { __device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
thread_reduce_<zero_init>(tensor, summary, op); thread_reduce_<zero_init>(tensor, summary, op);
quad_allreduce_(summary, summary, op); quad_allreduce_(summary, summary, op);
} }
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1> template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ inline void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){ __device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
MaxOp<float> max_op; MaxOp<float> max_op;
reduce_<zero_init>(tensor, max, max_op); reduce_<zero_init>(tensor, max, max_op);
} }
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1> template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ inline void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){ __device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
SumOp<float> sum_op; SumOp<float> sum_op;
reduce_(tensor, sum, sum_op); thread_reduce_<zero_init>(tensor, sum, sum_op);
} }
// Apply the exp to all the elements. // Apply the exp to all the elements.
template <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1> template <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) { __forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor"); static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor"); static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
@ -104,7 +85,7 @@ inline __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor
// Apply the exp to all the elements. // Apply the exp to all the elements.
template <bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1> template <bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) { __forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor"); static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor"); static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
@ -134,171 +115,67 @@ inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tens
} }
} }
template <typename Engine, typename Layout> ////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
const int col_idx_offset_ = 0) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
const int lane_id = threadIdx.x % 32;
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
if (col_idx >= max_seqlen_k) {
// Without the "make_coord" we get wrong results
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
tensor(mi, make_coord(j, nj)) = -INFINITY;
}
}
}
}
}
template <bool HasWSLeft=true, typename Engine, typename Layout> template <int kNRows>
inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_, struct Softmax {
const int max_seqlen_k, const int row_idx_offset_,
const int max_seqlen_q, const int warp_row_stride, using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
const int window_size_left, const int window_size_right) { TensorT row_max, row_sum;
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor"); __forceinline__ __device__ Softmax() {};
const int lane_id = threadIdx.x % 32;
// const int row_idx_offset = row_idx_offset_ + lane_id / 4; template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1>
const int row_idx_offset = row_idx_offset_; __forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) {
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
#pragma unroll Tensor scores = make_tensor(acc_s.data(), pytorch_flash::convert_layout_acc_rowcol(acc_s.layout()));
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { static_assert(decltype(size<0>(scores))::value == kNRows);
const int row_idx_base = row_idx_offset + mi * warp_row_stride; if (Is_first) {
#pragma unroll pytorch_flash::template reduce_max</*zero_init=*/true>(scores, row_max);
for (int i = 0; i < size<0, 0>(tensor); ++i) { pytorch_flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
const int row_idx = row_idx_base + i * 8; pytorch_flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); } else {
const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); Tensor scores_max_prev = make_fragment_like(row_max);
cute::copy(row_max, scores_max_prev);
pytorch_flash::template reduce_max</*zero_init=*/false>(scores, row_max);
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), pytorch_flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll #pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { for (int mi = 0; mi < size(row_max); ++mi) {
const int col_idx_base = col_idx_offset + nj * 8; float scores_max_cur = !Check_inf
? row_max(mi)
: (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
row_sum(mi) *= scores_scale;
#pragma unroll #pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) { for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
const int col_idx = col_idx_base + j;
if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) {
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
}
}
} }
// if (cute::thread0()) { pytorch_flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
// printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k); // We don't do the reduce across threads here since we don't need to use the row_sum.
// print(tensor(make_coord(i, mi), _)); // We do that reduce at the end when we need to normalize the softmax.
// // print(tensor(_, j + nj * size<1, 0>(tensor))); pytorch_flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
// }
} }
}
}
template <typename Engine, typename Layout>
inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
const int max_seqlen_k, const int row_idx_offset_,
const int max_seqlen_q, const int warp_row_stride) {
// Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset_,
max_seqlen_q, warp_row_stride, -1, 0);
}
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void apply_mask_causal_w_idx(
Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset_)
{
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 2, "Only support 2D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol));
CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0)));
#pragma unroll
for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
tensor(mi, ni) = -INFINITY;
}
}
// if (cute::thread0()) {
// printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k);
// print(tensor(_, make_coord(j, ni)));
// // print(tensor(_, j + ni * size<1, 0>(tensor)));
// }
}
}
template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
inline __device__ void apply_dropout(Tensor<Engine, Layout> &tensor, uint8_t p_dropout_in_uint8_t,
unsigned long long seed, unsigned long long offset,
int block_row_start, int block_col_start,
int block_row_stride) {
// tensor has shape (8, MMA_M, MMA_N / 2)
using T = typename Engine::value_type;
auto encode_dropout = [](bool keep, T val) {
return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0));
}; };
static_assert(decltype(size<2>(tensor))::value % 2 == 0);
const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t); template<bool Is_dropout=false, bool Split=false, typename Tensor0>
const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t); __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
// if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); } SumOp<float> sum_op;
#pragma unroll quad_allreduce_(row_sum, row_sum, sum_op);
for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) { TensorT lse = make_fragment_like(row_sum);
uint2 rowcol = make_uint2(block_row_start, block_col_start); Tensor acc_o_rowcol = make_tensor(acc_o.data(), pytorch_flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll #pragma unroll
for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) { for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
// if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));} float sum = row_sum(mi);
uint4 random_uint4 = pytorch_flash::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset); float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
// if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);} lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4); float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
// Special implementation for 16-bit types: we duplicate the threshold to the #pragma unroll
// low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
// to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000,
// and the high 16 bits will be either 0xffff or 0x0000, depending on whether
// the random value is less than the threshold.
// We then do a bit-wise AND between the mask and the original value (in 32-bit).
// We're exploiting the fact that floating point comparison is equivalent to integer
// comparison, since we're comparing unsigned integers whose top 8-bits are zero.
if (!encode_dropout_in_sign_bit
&& (std::is_same<T, cutlass::half_t>::value || std::is_same<T, cutlass::bfloat16_t>::value)) {
uint16_t rnd_16[16];
#pragma unroll
for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); }
uint32_t (&rnd_32)[8] = reinterpret_cast<uint32_t (&)[8]>(rnd_16);
#pragma unroll
for (int j = 0; j < 2; j++) {
Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
// if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); }
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
#pragma unroll
for (int i = 0; i < 4; i++) {
uint32_t mask;
asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t));
tensor_uint32(i) &= mask;
}
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
}
} else {
#pragma unroll
for (int j = 0; j < 2; j++) {
#pragma unroll
for (int i = 0; i < 8; i++) {
tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j));
}
Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
}
}
// // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// // printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w);
// // }
} }
} return lse;
} };
};
} // namespace pytorch_flash } // namespace pytorch_flash
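To make the new Softmax<kNRows> struct easier to follow, here is a hedged PyTorch sketch of the same online-softmax recurrence: a running row max and row sum are kept across key blocks, previously accumulated output is rescaled whenever the max grows, and the final pass produces both the normalized output and the logsumexp. The kernel works in exp2/log2 on register fragments; this reference uses plain exp on dense tensors, and every name below is illustrative.

import torch

def online_softmax_ref(score_blocks, v_blocks, scale):
    # score_blocks: list of (rows, block_cols) raw QK^T tiles; v_blocks: matching (block_cols, d) tiles.
    rows, d = score_blocks[0].shape[0], v_blocks[0].shape[1]
    row_max = torch.full((rows,), float("-inf"))
    row_sum = torch.zeros(rows)
    acc_o = torch.zeros(rows, d)
    for s, v in zip(score_blocks, v_blocks):
        s = s * scale
        new_max = torch.maximum(row_max, s.max(dim=-1).values)
        correction = torch.exp(row_max - new_max)       # softmax_rescale_o: rescale old partials
        row_sum = row_sum * correction
        acc_o = acc_o * correction.unsqueeze(-1)
        p = torch.exp(s - new_max.unsqueeze(-1))
        row_sum = row_sum + p.sum(dim=-1)
        acc_o = acc_o + p @ v
        row_max = new_max
    lse = row_max + torch.log(row_sum)                  # normalize_softmax_lse
    return acc_o / row_sum.unsqueeze(-1), lse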

View File

@ -14,6 +14,7 @@
/// some_function<BoolConst>(...); /// some_function<BoolConst>(...);
/// }); /// });
/// ``` /// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...) \ #define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \ [&] { \
if (COND) { \ if (COND) { \
@ -25,6 +26,46 @@
} \ } \
}() }()
#ifdef FLASHATTENTION_DISABLE_DROPOUT
#define DROPOUT_SWITCH(COND, CONST_NAME, ...) \
[&] { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
}()
#else
#define DROPOUT_SWITCH BOOL_SWITCH
#endif
#ifdef FLASHATTENTION_DISABLE_ALIBI
#define ALIBI_SWITCH(COND, CONST_NAME, ...) \
[&] { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
}()
#else
#define ALIBI_SWITCH BOOL_SWITCH
#endif
#ifdef FLASHATTENTION_DISABLE_UNEVEN_K
#define EVENK_SWITCH(COND, CONST_NAME, ...) \
[&] { \
constexpr static bool CONST_NAME = true; \
return __VA_ARGS__(); \
}()
#else
#define EVENK_SWITCH BOOL_SWITCH
#endif
#ifdef FLASHATTENTION_DISABLE_LOCAL
#define LOCAL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
}()
#else
#define LOCAL_SWITCH BOOL_SWITCH
#endif
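// In summary: when the matching FLASHATTENTION_DISABLE_* macro is defined at build time,
// each switch above pins CONST_NAME to a compile-time constant (false, or true in the
// EVENK_SWITCH case) so the disabled kernel variants are never instantiated; otherwise
// the switch falls back to the runtime BOOL_SWITCH.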
#define FP16_SWITCH(COND, ...) \ #define FP16_SWITCH(COND, ...) \
[&] { \ [&] { \
if (COND) { \ if (COND) { \
@ -36,7 +77,7 @@
} \ } \
}() }()
#define FWD_HEADDIM_SWITCH(HEADDIM, ...) \ #define HEADDIM_SWITCH(HEADDIM, ...) \
[&] { \ [&] { \
if (HEADDIM <= 32) { \ if (HEADDIM <= 32) { \
constexpr static int kHeadDim = 32; \ constexpr static int kHeadDim = 32; \

View File

@ -22,16 +22,17 @@
#include <cutlass/numeric_conversion.h> #include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h> #include <cutlass/numeric_types.h>
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
namespace pytorch_flash { namespace pytorch_flash {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T> template<typename T>
inline __device__ uint32_t relu2(const uint32_t x); __forceinline__ __device__ uint32_t relu2(const uint32_t x);
template<> template<>
inline __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) { __forceinline__ __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) {
uint32_t res; uint32_t res;
const uint32_t zero = 0u; const uint32_t zero = 0u;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
@ -49,7 +50,7 @@ inline __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template<> template<>
inline __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) { __forceinline__ __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) {
uint32_t res; uint32_t res;
const uint32_t zero = 0u; const uint32_t zero = 0u;
asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
@ -62,10 +63,10 @@ inline __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template<typename T> template<typename T>
inline __device__ uint32_t convert_relu2(const float2 x); __forceinline__ __device__ uint32_t convert_relu2(const float2 x);
template<> template<>
inline __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) { __forceinline__ __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) {
uint32_t res; uint32_t res;
const uint32_t a = reinterpret_cast<const uint32_t&>(x.x); const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
const uint32_t b = reinterpret_cast<const uint32_t&>(x.y); const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
@ -74,7 +75,7 @@ inline __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) {
} }
template<> template<>
inline __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) { __forceinline__ __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
uint32_t res; uint32_t res;
const uint32_t a = reinterpret_cast<const uint32_t&>(x.x); const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
const uint32_t b = reinterpret_cast<const uint32_t&>(x.y); const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
@ -88,20 +89,20 @@ inline __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
template<typename T> template<typename T>
struct MaxOp { struct MaxOp {
__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; } __device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
}; };
template <> template <>
struct MaxOp<float> { struct MaxOp<float> {
// This is slightly faster // This is slightly faster
__device__ inline float operator()(float const &x, float const &y) { return max(x, y); } __device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T> template<typename T>
struct SumOp { struct SumOp {
__device__ inline T operator()(T const & x, T const & y) { return x + y; } __device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
@ -110,7 +111,7 @@ template<int THREADS>
struct Allreduce { struct Allreduce {
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
template<typename T, typename Operator> template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) { static __device__ __forceinline__ T run(T x, Operator &op) {
constexpr int OFFSET = THREADS / 2; constexpr int OFFSET = THREADS / 2;
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
return Allreduce<OFFSET>::run(x, op); return Allreduce<OFFSET>::run(x, op);
@ -122,7 +123,7 @@ struct Allreduce {
template<> template<>
struct Allreduce<2> { struct Allreduce<2> {
template<typename T, typename Operator> template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) { static __device__ __forceinline__ T run(T x, Operator &op) {
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
return x; return x;
} }
@ -134,7 +135,7 @@ template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename
typename Tensor2, typename Tensor3, typename Tensor4, typename Tensor2, typename Tensor3, typename Tensor4,
typename TiledMma, typename TiledCopyA, typename TiledCopyB, typename TiledMma, typename TiledCopyA, typename TiledCopyB,
typename ThrCopyA, typename ThrCopyB> typename ThrCopyA, typename ThrCopyB>
inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, __forceinline__ __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
Tensor4 const& tCsB, TiledMma tiled_mma, Tensor4 const& tCsB, TiledMma tiled_mma,
TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B, TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) { ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) {
@ -161,9 +162,9 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3
template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3, template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
typename TiledMma, typename TiledCopy, typename ThrCopy> typename TiledMma, typename TiledCopy, typename ThrCopy>
inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB, __forceinline__ __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
ThrCopy smem_thr_copy_B) { ThrCopy smem_thr_copy_B) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
@ -183,42 +184,48 @@ inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB
// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) // Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
template<typename Layout> template<typename Layout>
inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { __forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
static_assert(decltype(size<0>(acc_layout))::value == 4); static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3); static_assert(decltype(rank(acc_layout))::value == 3);
auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N)
// TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
// "int_tuple.hpp(74): error: conversion to inaccessible base class"
// return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l)));
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) // Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
// if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. // if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8.
template<typename MMA_traits, typename Layout> template<typename MMA_traits, typename Layout>
inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) { __forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
using X = Underscore; using X = Underscore;
static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2); static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2); static_assert(decltype(rank(acc_layout))::value == 3);
constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{}); constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
static_assert(mma_shape_K == 8 || mma_shape_K == 16); static_assert(mma_shape_K == 8 || mma_shape_K == 16);
constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2; if constexpr (mma_shape_K == 8) {
auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<MMA_N_divisor>>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2))) return acc_layout;
// TD [2023-08-13]: Same error as above on Cutlass 3.2 } else {
// return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)), auto l = logical_divide(acc_layout, Shape<X, X, _2>{}); // (4, MMA_M, (2, MMA_N / 2)))
// get<0, 1>(l), return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
// get<1, 1, 1>(l)); }
return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), get<0>(get<1>(get<1>(l)))), };
get<1>(get<0>(l)),
get<1>(get<1>(get<1>(l)))); ////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
template<typename Layout>
__forceinline__ __device__ auto convert_layout_acc_dropout(Layout acc_layout) {
using X = Underscore;
static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = logical_divide(acc_layout, Shape<X, X, _2>{}); // (4, MMA_M, (2, MMA_N / 2)))
return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename To_type, typename Engine, typename Layout> template <typename To_type, typename Engine, typename Layout>
inline __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) { __forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
using From_type = typename Engine::value_type; using From_type = typename Engine::value_type;
constexpr int numel = decltype(size(tensor))::value; constexpr int numel = decltype(size(tensor))::value;
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op; cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
@ -230,7 +237,7 @@ inline __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Engine, typename Layout> template <typename Engine, typename Layout>
inline __device__ void relu_(Tensor<Engine, Layout> &tensor) { __forceinline__ __device__ void relu_(Tensor<Engine, Layout> &tensor) {
constexpr int numel = decltype(size(tensor))::value; constexpr int numel = decltype(size(tensor))::value;
static_assert(numel % 2 == 0); static_assert(numel % 2 == 0);
using value_t = typename Engine::value_type; using value_t = typename Engine::value_type;
@ -246,7 +253,7 @@ inline __device__ void relu_(Tensor<Engine, Layout> &tensor) {
// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction // On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction
template <typename To_type, typename Engine, typename Layout> template <typename To_type, typename Engine, typename Layout>
inline __device__ auto convert_type_relu(Tensor<Engine, Layout> const &tensor) { __forceinline__ __device__ auto convert_type_relu(Tensor<Engine, Layout> const &tensor) {
using From_type = typename Engine::value_type; using From_type = typename Engine::value_type;
static_assert(std::is_same_v<To_type, cutlass::half_t> || std::is_same_v<To_type, cutlass::bfloat16_t>); static_assert(std::is_same_v<To_type, cutlass::half_t> || std::is_same_v<To_type, cutlass::bfloat16_t>);
static_assert(std::is_same_v<float, From_type>); static_assert(std::is_same_v<float, From_type>);
@ -288,7 +295,7 @@ void cp_async_wait() {
template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true, template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3> typename Engine2, typename Layout2, typename Engine3, typename Layout3>
inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S, __forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN, Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) { Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
@ -357,7 +364,7 @@ inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const
template <bool Is_even_K=true, template <bool Is_even_K=true,
typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3> typename Engine2, typename Layout2, typename Engine3, typename Layout3>
inline __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S, __forceinline__ __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN, Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
Tensor<Engine3, Layout3> const &predicate_K, Tensor<Engine3, Layout3> const &predicate_K,
const int max_MN=0, const int min_MN=0) { const int max_MN=0, const int min_MN=0) {
@ -384,137 +391,4 @@ inline __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template <bool Is_even_K=true, bool Clear_OOB_K=true,
typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
inline __device__ void copy_rotary_interleaved(Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D,
Tensor<Engine2, Layout2> const &Cos,
Tensor<Engine2, Layout2> const &Sin,
Tensor<Engine3, Layout3> const &identity_MN,
const int max_MN, const int min_MN,
const int dim, const int rotary_dim) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K
CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K
static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2);
static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32
Tensor rCos = make_fragment_like(Cos);
Tensor rSin = make_fragment_like(Sin);
Tensor rS = make_fragment_like(S);
#pragma unroll
for (int m = 0; m < size<1>(S); ++m) {
if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
#pragma unroll
for (int k = 0; k < size<2>(S); ++k) {
if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
cute::copy(S(_, m, k), rS(_, m, k));
if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
cute::copy(Cos(_, m, k), rCos(_, m, k));
cute::copy(Sin(_, m, k), rSin(_, m, k));
Tensor S_fp32 = convert_type<float>(rS(_, m, k));
Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
#pragma unroll
for (int i = 0; i < size<0>(rS) / 2; ++i) {
float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i);
float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i);
S_fp32(2 * i) = real;
S_fp32(2 * i + 1) = imag;
}
// Idk but I need to copy for the convert_type to work
Tensor S_fp32_copy = make_fragment_like(S_fp32);
cute::copy(S_fp32, S_fp32_copy);
using T = typename Engine0::value_type;
Tensor S_og_type = convert_type<T>(S_fp32_copy);
cute::copy(S_og_type, rS(_, m, k));
}
cute::copy(rS(_, m, k), D(_, m, k));
} else if (Clear_OOB_K) {
cute::clear(D(_, m, k));
}
}
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <bool Is_even_K=true, bool Clear_OOB_K=true,
typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
inline __device__ void copy_rotary_contiguous(Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D,
Tensor<Engine2, Layout2> const &Cos,
Tensor<Engine2, Layout2> const &Sin,
Tensor<Engine3, Layout3> const &identity_MN,
const int max_MN, const int min_MN,
const int dim, const int rotary_dim) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA
CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin));
static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32
Tensor rCos = make_fragment_like(Cos);
Tensor rSin = make_fragment_like(Sin);
Tensor rS = make_fragment_like(S);
Tensor rS_other = make_fragment_like(rS(_, 0, 0));
#pragma unroll
for (int m = 0; m < size<1>(S); ++m) {
if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
#pragma unroll
for (int k = 0; k < size<2>(S); ++k) {
if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
cute::copy(S(_, m, k), rS(_, m, k));
if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2;
Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout());
cute::copy(gS_other, rS_other);
// if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); }
Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout());
Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout());
cute::copy(gCos, rCos(_, m, k));
cute::copy(gSin, rSin(_, m, k));
// if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); }
Tensor S_fp32 = convert_type<float>(rS(_, m, k));
Tensor S_other_fp32 = convert_type<float>(rS_other);
Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
#pragma unroll
for (int i = 0; i < size<0>(rS); ++i) {
S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i));
}
// Idk but I need to copy for the convert_type to work
Tensor S_fp32_copy = make_fragment_like(S_fp32);
cute::copy(S_fp32, S_fp32_copy);
using T = typename Engine0::value_type;
Tensor S_og_type = convert_type<T>(S_fp32_copy);
cute::copy(S_og_type, rS(_, m, k));
// if (cute::thread0()) { print_tensor(rS(_, m, k)); }
}
cute::copy(rS(_, m, k), D(_, m, k));
} else if (Clear_OOB_K) {
cute::clear(D(_, m, k));
}
}
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace pytorch_flash } // namespace pytorch_flash

View File

@ -242,7 +242,7 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
return true; return true;
} }
bool check_requires_grad_and_head_dim_gt192_and_sm_ge86_lt90( bool check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89(
sdp_params const& params, sdp_params const& params,
bool debug) { bool debug) {
// Flash Attention will raise an error in the backward pass if the head_dim // Flash Attention will raise an error in the backward pass if the head_dim
@ -252,11 +252,19 @@ bool check_requires_grad_and_head_dim_gt192_and_sm_ge86_lt90(
auto dprops = at::cuda::getCurrentDeviceProperties(); auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm86_or_sm89 = check_sm_version<sm86, sm89>(dprops); bool is_sm86_or_sm89 = check_sm_version<sm86, sm89>(dprops);
bool is_head_dim_gt192 = params.query.sym_size(-1) > 192; bool is_head_dim_gt192 = params.query.sym_size(-1) > 192;
if (input_requires_grad(params) && is_sm86_or_sm89 && is_head_dim_gt192) { bool is_head_dim_lte224 = params.query.sym_size(-1) <= 224;
bool is_dropout = params.dropout > 0.0;
// head_dim size in (192, 224] is not supported on sm86 and sm89
bool cond1 = is_head_dim_gt192 && is_head_dim_lte224;
// head_dim size > 224 and is_dropout is not supported on sm86 and sm89
bool cond2 = params.query.sym_size(-1) > 224 && is_dropout;
if (input_requires_grad(params) && is_sm86_or_sm89 && (cond1 || cond2)) {
if (debug) { if (debug) {
TORCH_WARN( TORCH_WARN(
"Flash attention currently doesn't support training with head_dim greater than 192 on gpu architectures in the range[sm86, sm89].", "Flash attention currently doesn't support training with head_dim ∈ (192, 224] or "
"Attempting to run with head_dim: ", "(head_dim ∈ (224, 256] and dropout > 0.0) on gpu architectures in the range[sm86, sm89].",
"Attempting to run with dropout set to: ", params.dropout,
"and head_dim: ",
params.query.sym_size(-1), " on a sm ", dprops->major, ".", params.query.sym_size(-1), " on a sm ", dprops->major, ".",
dprops->minor, " gpu."); dprops->minor, " gpu.");
} }
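Restated as a small, hedged Python predicate that mirrors cond1/cond2 above (the function name is invented for illustration, and only this particular check is modeled; other constraints such as the overall head_dim limit are enforced elsewhere):

def flash_bwd_blocked_on_sm86_89(head_dim: int, dropout_p: float, requires_grad: bool) -> bool:
    cond1 = 192 < head_dim <= 224                 # (192, 224] is not supported for training
    cond2 = head_dim > 224 and dropout_p > 0.0    # > 224 is only blocked when dropout is on
    return requires_grad and (cond1 or cond2)

assert flash_bwd_blocked_on_sm86_89(200, 0.0, True)
assert not flash_bwd_blocked_on_sm86_89(256, 0.0, True)
assert flash_bwd_blocked_on_sm86_89(256, 0.2, True)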
@ -467,7 +475,7 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) {
check_for_attn_mask, check_for_attn_mask,
check_head_dim_size_flash, check_head_dim_size_flash,
check_flash_attention_hardware_support, check_flash_attention_hardware_support,
check_requires_grad_and_head_dim_gt192_and_sm_ge86_lt90, check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89,
check_flash_causal_non_square_seqlens, check_flash_causal_non_square_seqlens,
check_dtypes_low_precision); check_dtypes_low_precision);
for (auto& constraint : general_constraints) { for (auto& constraint : general_constraints) {

View File

@ -106,10 +106,11 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout, const float p_dropout,
const float softmax_scale, const float softmax_scale,
bool is_causal, bool is_causal,
const int window_size_left, int window_size_left,
int window_size_right, int window_size_right,
const bool return_softmax, const bool return_softmax,
c10::optional<at::Generator> gen_) { c10::optional<at::Generator> gen_) {
@ -311,13 +312,14 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1 const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used. c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
const int max_seqlen_q, c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
int max_seqlen_q,
const int max_seqlen_k, const int max_seqlen_k,
const float p_dropout, const float p_dropout,
const float softmax_scale, const float softmax_scale,
const bool zero_tensors, const bool zero_tensors,
const bool is_causal, bool is_causal,
const int window_size_left, int window_size_left,
int window_size_right, int window_size_right,
const bool return_softmax, const bool return_softmax,
c10::optional<at::Generator> gen_) { c10::optional<at::Generator> gen_) {
@ -343,11 +345,13 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout, // probability to drop const float p_dropout, // probability to drop
const float softmax_scale, const float softmax_scale,
const bool is_causal, const bool is_causal,
const int window_size_left, int window_size_left,
int window_size_right, int window_size_right,
const bool deterministic,
const at::Tensor philox_seed, const at::Tensor philox_seed,
const at::Tensor philox_offset) { const at::Tensor philox_offset) {
check_gpu_arch(); check_gpu_arch();
@ -630,14 +634,16 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
c10::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i c10::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1 const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
const int max_seqlen_q, const int max_seqlen_q,
const int max_seqlen_k, // max sequence length to choose the kernel const int max_seqlen_k, // max sequence length to choose the kernel
const float p_dropout, // probability to drop const float p_dropout, // probability to drop
const float softmax_scale, const float softmax_scale,
const bool zero_tensors, const bool zero_tensors,
const bool is_causal, const bool is_causal,
const int window_size_left, int window_size_left,
int window_size_right, int window_size_right,
const bool deterministic,
const at::Tensor philox_seed, const at::Tensor philox_seed,
const at::Tensor philox_offset) { const at::Tensor philox_offset) {
TORCH_CHECK(false, "mha_varlen_bwd not supported on ROCm"); TORCH_CHECK(false, "mha_varlen_bwd not supported on ROCm");
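The alibi_slopes_ argument added to these signatures carries one slope per head ("num_heads or batch_size x num_heads", per the comments above). Shown only to illustrate the expected shape, and not taken from this change, here is one common way such slopes are generated for a power-of-two head count, following the usual ALiBi recipe:

import torch

def make_alibi_slopes(num_heads: int) -> torch.Tensor:
    # Geometric series 2^(-8/n), 2^(-16/n), ...; assumes num_heads is a power of two.
    ratio = 2.0 ** (-8.0 / num_heads)
    return torch.tensor([ratio ** (i + 1) for i in range(num_heads)])

# make_alibi_slopes(8) -> tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])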

View File

@ -1328,13 +1328,17 @@ class TestSDPAFailureModes(NNTestCase):
_do_cuda_non_default_stream = True _do_cuda_non_default_stream = True
@onlyCUDA @onlyCUDA
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION or not isSM8XDevice, @unittest.skipIf(
"Does not support fused SDPA or not SM86+ hardware") not PLATFORM_SUPPORTS_FLASH_ATTENTION or not isSM8XDevice,
"Does not support fused SDPA or not SM86+ hardware",
)
@parametrize("head_dim", [193, 204, 256]) @parametrize("head_dim", [193, 204, 256])
def test_flash_backward_failure_sm86plus(self, device, head_dim: int): @parametrize("dropout_p", [0.0, 0.2])
def test_flash_backward_failure_sm86plus(self, device, head_dim: int, dropout_p: float):
dtype = torch.float16 dtype = torch.float16
make_tensor = partial(torch.rand, device=device, dtype=dtype) make_tensor = partial(torch.rand, device=device, dtype=dtype)
# See check_requires_grad_and_head_dim_gt64_and_sm_ge86 in pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.h # See check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89 in
# pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.h
size = (2, 2, 4, head_dim) size = (2, 2, 4, head_dim)
q, k, v = make_tensor(size), make_tensor(size), make_tensor(size) q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
@ -1351,8 +1355,15 @@ class TestSDPAFailureModes(NNTestCase):
q = make_tensor(size, requires_grad=True) q = make_tensor(size, requires_grad=True)
k = make_tensor(size, requires_grad=True) k = make_tensor(size, requires_grad=True)
v = make_tensor(size, requires_grad=True) v = make_tensor(size, requires_grad=True)
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( if 192 < head_dim <= 224 or (head_dim > 224 and dropout_p != 0.0):
q, k, v, None, 0.0, False)) self.assertRaises(
RuntimeError,
lambda: torch.nn.functional.scaled_dot_product_attention(
q, k, v, None, dropout_p, False
),
)
else:
flash_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, None, dropout_p, False)
@onlyCUDA @onlyCUDA
def test_dispatch_fails_no_backend(self, device): def test_dispatch_fails_no_backend(self, device):
@ -1589,7 +1600,6 @@ class TestSDPAFailureModes(NNTestCase):
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
q, k, v, None, 0.0, False)) q, k, v, None, 0.0, False))
@onlyCUDA @onlyCUDA
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION or not isLessThanSM80Device, @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION or not isLessThanSM80Device,
"Current platform does not support fused SDPA or is an SM80+ device.") "Current platform does not support fused SDPA or is an SM80+ device.")
@ -1670,37 +1680,35 @@ class TestSDPAFailureModes(NNTestCase):
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
q, k, v, None, 0.0, is_causal=True)) q, k, v, None, 0.0, is_causal=True))
def _get_block_size(device, head_dim, is_causal): def _get_block_size_n(device, head_dim, is_dropout, is_causal):
# This should match the block sizes in the CUDA kernel # This should match the block sizes in the CUDA kernel
# Mask is only interesting when we are setting dropout
is_dropout = True
assert head_dim <= 256 assert head_dim <= 256
major, minor = torch.cuda.get_device_capability(device) major, minor = torch.cuda.get_device_capability(device)
is_sm8x = major == 8 and minor > 0 # Only include sm86 and sm89, exclude sm80 (A100) is_sm8x = major == 8 and minor > 0 # Only include sm86 and sm89, exclude sm80 (A100)
is_sm80 = major == 8 and minor == 0 is_sm80 = major == 8 and minor == 0
is_sm90 = major == 9 and minor == 0 is_sm90 = major == 9 and minor == 0
if head_dim <= 32: if head_dim <= 32:
return 128, 128 return 128
if head_dim <= 64: if head_dim <= 64:
return (128, 128) if not is_dropout else (128, 64) return 128 if not is_dropout else 64
elif head_dim <= 96: elif head_dim <= 96:
return (64, 64) if (is_sm8x and is_causal) else (128, 64) return 64
elif head_dim <= 128: elif head_dim <= 128:
if is_sm8x: if is_sm8x:
return (64, 64) if (not is_dropout and is_causal) else (128, 32) return 64 if (not is_dropout and is_causal) else 32
else: else:
return 128, (64 if not is_dropout else 32) return 64 if not is_dropout else 32
elif head_dim <= 160: elif head_dim <= 160:
if is_sm8x: if is_sm8x:
return (128, 64) if not is_causal else (64, 64) return 64
else: else:
return 128, 32 return 32
elif head_dim <= 192: elif head_dim <= 192:
return (128, 64) if not is_dropout else (64, 64) return 64
elif head_dim <= 224: elif head_dim <= 224:
return (128, 64) if (is_sm80 or is_sm90) else (64, 64) return 64
elif head_dim <= 256: elif head_dim <= 256:
return (128, 64) if is_sm80 else (64, 64) return 64
def pad_last_dim(input_tensor, alignment_size, slice: bool = False): def pad_last_dim(input_tensor, alignment_size, slice: bool = False):
@ -1963,7 +1971,114 @@ class TestSDPACudaOnly(NNTestCase):
_do_cuda_memory_leak_check = True _do_cuda_memory_leak_check = True
_do_cuda_non_default_stream = True _do_cuda_non_default_stream = True
def convert_flash_attn_S_to_softmax(self, S, query_padding_mask, key_padding_mask, head_dim, causal=False):
# TODO USED FOR TESTING THE SCORES, e.g. testing ALIBI we don't need this now
def normalize_flash_attn_S(
self,
attn_unnorm,
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
attn_bias=None,
is_dropout=False,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
scale=None,
):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, head_dim)
k, v: (batch_size, seqlen_k, nheads, head_dim)
key_padding_mask: (batch_size, seqlen_q)
attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
Output:
softmax_lse: (batch_size, nheads, seqlen_q)
softmax_max: (batch_size, nheads, seqlen_q)
"""
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
if causal:
window_size = (window_size[0], 0)
q, k, v = q.float(), k.float(), v.float()
_, seqlen_q, _, head_dim = q.shape
seqlen_k = k.shape[1]
b = q.shape[0]
from torch.nn.attention.bias import _calculate_scale
scale = _calculate_scale(head_dim, scale)
scores = torch.matmul(q.transpose(1, 2) * scale, k.permute(0, 2, 3, 1))
if key_padding_mask is not None:
scores.masked_fill_(~key_padding_mask.view(b, 1, 1, -1), float("-inf"))
if window_size[0] >= 0 or window_size[1] >= 0:
local_mask = self.construct_local_mask(
seqlen_q,
seqlen_k,
window_size,
query_padding_mask,
key_padding_mask,
q.device,
)
scores.masked_fill_(local_mask, float("-inf"))
if attn_bias is not None:
scores = scores + attn_bias.to(dtype=scores.dtype)
block_size_n = _get_block_size_n(scores.device, head_dim, is_dropout, causal)
scores_block = scores.split(block_size_n, dim=-1)
lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1)
lse = torch.logsumexp(lse_block, dim=-1)
# lse could be -inf (i.e. all values in scores are -inf), and we want to set those to inf
# so that when we do torch.exp(m - lse), we get 0.0 instead of NaN.
lse[lse == float("-inf")] = float("inf")
scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1)
cummax_block = torch.cummax(scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1)
attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1)
attn_norm = torch.cat(
[
a * (torch.exp(m - lse)).unsqueeze(-1)
for a, m in zip(attn_unnorm_block, cummax_block)
],
dim=-1,
)
if query_padding_mask is not None:
attn_norm.masked_fill_(~query_padding_mask.view(b, 1, -1, 1), 0.0)
# attn_norm.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
return attn_norm.to(dtype=attn_unnorm.dtype)
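# Illustrative sketch, not part of this diff: the rescaling identity behind the
# normalization above. The kernel emits per-block unnormalized probabilities of the
# form exp(score - m), where m is a per-block running max (the suffix cummax computed
# above); multiplying by exp(m - lse) recovers exp(score - lse), i.e. the softmax.
# One-row, one-block numeric check (values are made up for the example):
import torch

def _renorm_identity_sketch():
    scores = torch.randn(5)                  # one attention row
    lse = torch.logsumexp(scores, dim=-1)
    m = scores.max()                         # stand-in for the block max
    attn_unnorm = torch.exp(scores - m)      # what the kernel stores
    attn = attn_unnorm * torch.exp(m - lse)  # rescaled as in normalize_flash_attn_S
    assert torch.allclose(attn, torch.softmax(scores, dim=-1), atol=1e-6)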
def construct_local_mask(self, seqlen_q, seqlen_k, window_size, query_padding_mask, key_padding_mask, device):
# row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
row_idx = torch.arange(seqlen_q, device=device, dtype=torch.long).view(-1, 1)
col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
sk = (
seqlen_k
if key_padding_mask is None
else key_padding_mask.sum(-1).view(-1, 1, 1, 1)
# else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
)
sq = (
seqlen_q
if query_padding_mask is None
else query_padding_mask.sum(-1).view(-1, 1, 1, 1)
# else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
)
if window_size[0] < 0:
return col_idx > row_idx + sk - sq + window_size[1]
else:
sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
return torch.logical_or(
col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
col_idx < row_idx + sk - sq - window_size[0],
)
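# Illustrative sketch, not part of this diff: what construct_local_mask returns for a
# tiny case. True marks positions that will be filled with -inf. With
# seqlen_q == seqlen_k == 4 and window_size == (1, 0) (one key to the left plus the
# diagonal, i.e. a causal sliding window), every query sees at most two keys.
# `tc` below is assumed to be an instance of this test class (or the method lifted out).
import torch

def _local_mask_sketch(tc):
    mask = tc.construct_local_mask(
        seqlen_q=4, seqlen_k=4, window_size=(1, 0),
        query_padding_mask=None, key_padding_mask=None, device="cpu",
    )
    expected = torch.tensor([
        [False, True, True, True],
        [False, False, True, True],
        [True, False, False, True],
        [True, True, False, False],
    ])
    assert torch.equal(mask, expected)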
def convert_flash_attn_S_to_softmax(
self,
S,
seqlen_q,
seqlen_k,
query_padding_mask,
key_padding_mask,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
):
"""FlashAttention stores the S matrix in a different way. """FlashAttention stores the S matrix in a different way.
Arguments: Arguments:
S: (batch_size, nheads, seqlen_q, seqlen_k) S: (batch_size, nheads, seqlen_q, seqlen_k)
@ -1972,53 +2087,45 @@ class TestSDPACudaOnly(NNTestCase):
""" """
if TEST_WITH_ROCM: if TEST_WITH_ROCM:
return S return S
b = S.shape[0]
b, h, seqlen_q, seqlen_k = S.shape
warps_n = 4
blocksize_m, blocksize_n = _get_block_size(S.device, head_dim, causal)
nblocks_m = (seqlen_q + blocksize_m - 1) // blocksize_m
nblocks_n = (seqlen_k + blocksize_n - 1) // blocksize_n
mmas_n = (blocksize_n + 16 - 1) // 16
# Reshape S using PyTorch native functions
S_flat = S.view(b, h, nblocks_m, blocksize_m, nblocks_n, blocksize_n)
S_flat = S_flat.permute(0, 1, 2, 4, 3, 5)
S_flat = S_flat.reshape(b, h, nblocks_m, nblocks_n, (blocksize_m * blocksize_n))
S_converted = S_flat.view(b, h, nblocks_m, nblocks_n, mmas_n, -1, warps_n, 8, 4, 2, 2, 2)
S_converted = S_converted.permute(0, 1, 2, 5, 6, 10, 7, 3, 4, 9, 8, 11)
S_converted = S_converted.reshape(b, h, (nblocks_m * S_converted.size(3) *
warps_n * 2 * 8), (nblocks_n * mmas_n * 2 * 4 * 2))
if causal: if causal:
causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=S.device), 1)
S_converted.masked_fill_(causal_mask, 0.0)
window_size = (window_size[0], 0)
seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:]
S_converted = S
if window_size[0] >= 0 or window_size[1] >= 0:
local_mask = self.construct_local_mask(
seqlen_q,
seqlen_k,
window_size,
query_padding_mask,
key_padding_mask,
S.device,
)
local_mask = F.pad(
local_mask,
(0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q),
value=True,
)
S_converted = S_converted.masked_fill(local_mask, 0.0)
# Need to zero out things not in attention_mask in case S was initialized with random values # Need to zero out things not in attention_mask in case S was initialized with random values
# and some of those values aren't overwritten. # and some of those values aren't overwritten.
seqlen_q_og = query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q
seqlen_q_og = (
query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded
)
if query_padding_mask is not None: if query_padding_mask is not None:
if seqlen_q_og < seqlen_q:
query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q - seqlen_q_og))
else:
query_padding_mask = query_padding_mask[:, :seqlen_q]
q_mask_fill = ~query_padding_mask.view(query_padding_mask.shape[0], 1, query_padding_mask.shape[1], 1)
S_converted = S_converted.masked_fill(q_mask_fill, 0.0)
query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q_rounded - seqlen_q_og))
# S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
S_converted = S_converted.masked_fill(~query_padding_mask.view(b, 1, -1, 1), 0.0)
seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k
if key_padding_mask is not None: if key_padding_mask is not None:
if seqlen_k_og < seqlen_k:
key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k - seqlen_k_og))
else:
key_padding_mask = key_padding_mask[:, :seqlen_k]
k_mask_fill = ~key_padding_mask.view(key_padding_mask.shape[0], 1, 1, key_padding_mask.shape[1])
S_converted = S_converted.masked_fill(k_mask_fill, 0.0)
if seqlen_q_og < seqlen_q:
S_converted = S_converted[:, :, :seqlen_q_og, :]
else:
S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q))
if seqlen_k_og < seqlen_k:
S_converted = S_converted[:, :, :, :seqlen_k_og]
else:
S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k))
return S_converted
key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k_rounded - seqlen_k_og))
S_converted = S_converted.masked_fill(~key_padding_mask.view(b, 1, 1, -1), 0.0)
# S_converted = S_converted.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0)
S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q_rounded))
S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded))
return S_converted[:, :, :seqlen_q, :seqlen_k]
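# Illustrative sketch, not part of this diff: how the converted S is consumed further
# down in these tests. The debug S returned by the flash kernel (return_debug_mask=True)
# encodes dropped positions with a negative sign, so after un-padding with
# convert_flash_attn_S_to_softmax the keep-mask is simply `>= 0`, as the tests below do.
# Argument values are assumptions; `tc` stands in for an instance of this test class.
def _dropout_mask_sketch(tc, debug_mask, seq_len_q, seq_len_k, is_causal):
    softmax_mask = tc.convert_flash_attn_S_to_softmax(
        debug_mask, seq_len_q, seq_len_k, None, None, causal=is_causal)
    dropout_mask = softmax_mask >= 0  # True where the kernel kept the position
    return dropout_mask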
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
@parametrize("mask_dim", [1, 2, 3, 4]) @parametrize("mask_dim", [1, 2, 3, 4])
@ -2370,28 +2477,29 @@ class TestSDPACudaOnly(NNTestCase):
query, key, value = make_tensor(shape), make_tensor(shape), make_tensor(shape) query, key, value = make_tensor(shape), make_tensor(shape), make_tensor(shape)
with use_deterministic_algorithims(True, warn_only=warn_only): with use_deterministic_algorithims(True, warn_only=warn_only):
# Note that this should swith to a testing version with we remove old context manager
with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]): with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION.value assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION.value
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA")
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
@parametrize("fused_kernel", PLATFORM_SPECIFIC_SDPA)
@parametrize("warn_only", [True, False]) @parametrize("warn_only", [True, False])
def test_mem_eff_backwards_throws_determinism_warning(self, device, warn_only):
def test_fused_backwards_throws_determinism_warning(self, device, warn_only, fused_kernel):
batch_size, seq_len, num_heads, head_dim = 1, 64, 8, 64 batch_size, seq_len, num_heads, head_dim = 1, 64, 8, 64
shape = SdpaShape(batch_size, num_heads, seq_len, head_dim) shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=torch.float32, packed=False, requires_grad=True)
make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=torch.float16, packed=False, requires_grad=True)
query, key, value = make_tensor(shape), make_tensor(shape), make_tensor(shape) query, key, value = make_tensor(shape), make_tensor(shape), make_tensor(shape)
kernel_name = "Memory Efficient attention" if fused_kernel == SDPBackend.EFFICIENT_ATTENTION else "Flash Attention"
warning_context = ( warning_context = (
self.assertWarnsRegex( self.assertWarnsRegex(
UserWarning, UserWarning,
"Memory Efficient attention defaults to a non-deterministic algorithm.", f"{kernel_name} defaults to a non-deterministic algorithm.",
) )
if warn_only if warn_only
else contextlib.nullcontext() else contextlib.nullcontext()
) )
with use_deterministic_algorithims(True, warn_only=warn_only): with use_deterministic_algorithims(True, warn_only=warn_only):
with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
with sdpa_kernel(backends=[fused_kernel]):
with warning_context: with warning_context:
torch.nn.functional.scaled_dot_product_attention(query, key, value).sum().backward() torch.nn.functional.scaled_dot_product_attention(query, key, value).sum().backward()
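# Illustrative sketch, not part of this diff: the behaviour the parametrized test above
# checks. With deterministic algorithms in warn-only mode, selecting a fused SDPA
# backend warns that its backward is non-deterministic instead of raising. Shapes and
# dtype below are assumptions for the example.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

def _determinism_warning_sketch():
    torch.use_deterministic_algorithms(True, warn_only=True)
    try:
        make = lambda: torch.randn(1, 8, 64, 64, device="cuda", dtype=torch.float16, requires_grad=True)
        q, k, v = make(), make(), make()
        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
            F.scaled_dot_product_attention(q, k, v).sum().backward()  # warns, does not raise
    finally:
        torch.use_deterministic_algorithms(False)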
@ -2710,8 +2818,6 @@ class TestSDPACudaOnly(NNTestCase):
is_dropout = dropout_p > 0.0 is_dropout = dropout_p > 0.0
if not is_dropout: if not is_dropout:
# Problem: We pad sizes in the composite region of the top level SDPA. But we need the
# Debug mask when have dropout. So I am going to manualy pad up here when testing dropout
with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale) out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
with sdpa_kernel(backends=[SDPBackend.MATH]): with sdpa_kernel(backends=[SDPBackend.MATH]):
@ -2722,6 +2828,8 @@ class TestSDPACudaOnly(NNTestCase):
out_lp_ref = F.scaled_dot_product_attention( out_lp_ref = F.scaled_dot_product_attention(
query_ref_lp, key_ref_lp, value_ref_lp, is_causal=is_causal, scale=scale) query_ref_lp, key_ref_lp, value_ref_lp, is_causal=is_causal, scale=scale)
else: else:
# Problem: We pad sizes in the composite region of the top level SDPA. But we need the
# Debug mask when we have dropout. So I am going to manually pad up here when testing dropout
q_padded, q_og_size = pad_last_dim(query, 8) q_padded, q_og_size = pad_last_dim(query, 8)
k_padded, k_og_size = pad_last_dim(key, 8) k_padded, k_og_size = pad_last_dim(key, 8)
v_padded, v_og_size = pad_last_dim(value, 8) v_padded, v_og_size = pad_last_dim(value, 8)
@ -2740,9 +2848,14 @@ class TestSDPACudaOnly(NNTestCase):
batch_size, seq_len_k, device=device, dtype=torch.bool) batch_size, seq_len_k, device=device, dtype=torch.bool)
softmax_mask = self.convert_flash_attn_S_to_softmax( softmax_mask = self.convert_flash_attn_S_to_softmax(
dbug_mask, query_padding_mask, key_padding_mask, head_dim=head_dim,
dbug_mask, seq_len_q, seq_len_k, query_padding_mask, key_padding_mask,
causal=is_causal)[:, :, :seq_len_q, :seq_len_k] causal=is_causal)[:, :, :seq_len_q, :seq_len_k]
dropout_mask = softmax_mask >= 0 dropout_mask = softmax_mask >= 0
# attn_unnorm = softmax_mask.abs()
# attn = self.normalize_flash_attn_S(attn_unnorm, q_padded,
# k_padded, v_padded, query_padding_mask,
# key_padding_mask, None, True, is_causal, scale=scale)
# High Precision Math Reference # High Precision Math Reference
out_ref = torch.ops.aten._scaled_dot_product_attention_math( out_ref = torch.ops.aten._scaled_dot_product_attention_math(
query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal, scale=scale, dropout_mask=dropout_mask)[0] query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal, scale=scale, dropout_mask=dropout_mask)[0]
@ -2823,7 +2936,8 @@ class TestSDPACudaOnly(NNTestCase):
batch_size, seq_len_k, device=device, dtype=torch.bool) batch_size, seq_len_k, device=device, dtype=torch.bool)
softmax_mask = self.convert_flash_attn_S_to_softmax( softmax_mask = self.convert_flash_attn_S_to_softmax(
dbug_mask, query_padding_mask, key_padding_mask, head_dim=head_dim, causal=is_causal)
dbug_mask, seq_len_q, seq_len_k, query_padding_mask, key_padding_mask,
causal=is_causal)[:, :, :seq_len_q, :seq_len_k]
dropout_mask = softmax_mask >= 0 dropout_mask = softmax_mask >= 0
return dropout_mask return dropout_mask
@ -3178,7 +3292,7 @@ class TestSDPACudaOnly(NNTestCase):
key_padding_mask = key_padding_mask.to("cuda") key_padding_mask = key_padding_mask.to("cuda")
softmax_mask = self.convert_flash_attn_S_to_softmax( softmax_mask = self.convert_flash_attn_S_to_softmax(
dbug_mask, query_padding_mask, key_padding_mask, head_dim=head_dim, causal=is_causal)
dbug_mask, max_seq_len_q, max_seq_len_kv, query_padding_mask, key_padding_mask, causal=is_causal)
dropout_mask = softmax_mask >= 0 dropout_mask = softmax_mask >= 0
nt_stack = [] nt_stack = []
for tensor_component in range(batch_size): for tensor_component in range(batch_size):